[Mlir-commits] [mlir] [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (PR #95710)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 21 09:30:05 PDT 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/95710

>From 0418e51cf33bc59cc6f19ed00edc8c2d62e4d9df Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 15 Jun 2024 10:46:44 -0500
Subject: [PATCH 01/13] implement canonicalizer for batched linalg operations

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  49 ++-----
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 121 ++++++++++++++++++
 .../linalg/opdsl/ops/core_named_ops.py        |   5 +
 3 files changed, 138 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index fad234a9dcae9..3cbfb58ed8506 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,41 +304,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: reciprocal
-  cpp_class_name: ReciprocalOp
-  doc: |-
-    Applies reciprocal(x) elementwise.
-
-    No numeric casting is performed on the input operand.
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: I
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<() -> ()>
-  - !LinalgOperandDefConfig
-    name: O
-    kind: output_tensor
-    type_var: T1
-    shape_map: affine_map<() -> ()>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<() -> ()>
-    - affine_map<() -> ()>
-  iterator_types: []
-  assignments:
-  - !ScalarAssign
-    arg: O
-    value: !ScalarExpression
-      scalar_fn:
-        kind: unary
-        fn_name: reciprocal
-        operands:
-        - !ScalarExpression
-          scalar_arg: I
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: round
   cpp_class_name: RoundOp
@@ -516,7 +481,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: erf
-  cpp_class_name: erfOp
+  cpp_class_name: ErfOp
   doc: |-
     Applies erf(x) elementwise.
 
@@ -959,7 +924,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: powf
-  cpp_class_name: PowFOp
+  cpp_class_name: PowfOp
   doc: |-
     Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
 
@@ -1622,6 +1587,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -1692,6 +1659,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -1762,6 +1731,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -2140,6 +2111,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -2208,6 +2181,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..ecd669165efc7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -42,6 +43,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <numeric>
 #include <optional>
 
 using namespace mlir;
@@ -578,6 +580,125 @@ class RegionBuilderHelper {
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+template <typename BatchOpTy, typename OpTy>
+struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
+  using OpRewritePattern<BatchOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = batchMatmulOp.getLoc();
+    auto inputs = batchMatmulOp.getDpsInputs();
+    auto inits = batchMatmulOp.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "expected 2 inputs and 1 init");
+    auto lhs = inputs[0];
+    auto rhs = inputs[1];
+    auto init = inits[0];
+
+    auto lhsType = cast<ShapedType>(lhs.getType());
+    auto rhsType = cast<ShapedType>(rhs.getType());
+    auto initType = cast<ShapedType>(init.getType());
+    if (ShapedType::isDynamic(lhsType.getShape()[0]) ||
+        lhsType.getShape()[0] != rhsType.getShape()[0] ||
+        rhsType.getShape()[0] != initType.getShape()[0])
+      return rewriter.notifyMatchFailure(
+          batchMatmulOp, "expected batch sizes of all operands to be same");
+
+    auto results = batchMatmulOp.getResults();
+    if (results.size() > 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "expected at most one result");
+
+    SmallVector<Type, 1> resultType;
+    if (results.size() == 1) {
+      auto oldResultType = cast<RankedTensorType>(results[0].getType());
+      resultType.push_back(
+          RankedTensorType::get(oldResultType.getShape().drop_front(1),
+                                oldResultType.getElementType()));
+    }
+
+    auto collapseSingletonDim = [&](Value val) -> Value {
+      SmallVector<ReassociationIndices> reassociation({{0, 1}});
+      auto valType = cast<ShapedType>(val.getType());
+      for (auto i = 2; i < valType.getRank(); i++)
+        reassociation.push_back({i});
+      if (isa<RankedTensorType>(valType)) {
+        RankedTensorType collapsedType = RankedTensorType::get(
+            valType.getShape().drop_front(1), valType.getElementType());
+        return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
+                                                        reassociation);
+      }
+      MemRefType collapsedType = MemRefType::get(
+          valType.getShape().drop_front(1), valType.getElementType());
+      return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
+                                                      reassociation);
+    };
+
+    auto collapsedLhs = collapseSingletonDim(lhs);
+    auto collapsedRhs = collapseSingletonDim(rhs);
+    auto collapsedInit = collapseSingletonDim(init);
+
+    auto collapsedOp = rewriter.create<OpTy>(
+        loc, resultType, ValueRange{collapsedLhs, collapsedRhs},
+        ValueRange{collapsedInit});
+    for (auto attr : batchMatmulOp->getAttrs()) {
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+        continue;
+      collapsedOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    if (results.size() < 1) {
+      rewriter.replaceOp(batchMatmulOp, collapsedOp);
+    } else {
+      SmallVector<ReassociationIndices> reassociation({{0, 1}});
+      auto resultType = cast<ShapedType>(results[0].getType());
+      for (auto i = 2; i < resultType.getRank(); i++)
+        reassociation.push_back({i});
+      Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
+          loc, resultType, collapsedOp.getResultTensors()[0], reassociation);
+      rewriter.replaceOp(batchMatmulOp, expandedResult);
+    }
+
+    return success();
+  }
+};
+
+} // namespace
+
+void BatchMatmulOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
+}
+
+void BatchMatmulTransposeAOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+      context);
+}
+
+void BatchMatmulTransposeBOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+      context);
+}
+
+void BatchMatvecOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
+}
+
+void BatchVecmatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 43410aaa6af1b..b4b36ba0bfe51 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -518,6 +518,7 @@ def batch_matmul(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -537,6 +538,7 @@ def batch_matmul_transpose_a(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
@@ -556,6 +558,7 @@ def batch_matmul_transpose_b(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -642,6 +645,7 @@ def batch_matvec(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.m, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -660,6 +664,7 @@ def batch_vecmat(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(

>From 02b2ca083d145fc88a9498480e4a831affdebf10 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 16 Jun 2024 12:01:33 -0500
Subject: [PATCH 02/13] add tests

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  34 +++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  12 +-
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 137 +++++++++++++++++-
 3 files changed, 174 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3cbfb58ed8506..41f90483c93b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,40 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: reciprocal
+  cpp_class_name: ReciprocalOp
+  doc: |-
+    Applies reciprocal(x) elementwise.
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: reciprocal
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: round
   cpp_class_name: RoundOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ecd669165efc7..4e47b6018c445 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -605,16 +605,12 @@ struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
     auto lhsType = cast<ShapedType>(lhs.getType());
     auto rhsType = cast<ShapedType>(rhs.getType());
     auto initType = cast<ShapedType>(init.getType());
-    if (ShapedType::isDynamic(lhsType.getShape()[0]) ||
-        lhsType.getShape()[0] != rhsType.getShape()[0] ||
-        rhsType.getShape()[0] != initType.getShape()[0])
-      return rewriter.notifyMatchFailure(
-          batchMatmulOp, "expected batch sizes of all operands to be same");
+    if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
+        initType.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1");
 
     auto results = batchMatmulOp.getResults();
-    if (results.size() > 1)
-      return rewriter.notifyMatchFailure(batchMatmulOp,
-                                         "expected at most one result");
+    assert(results.size() < 2 && "expected at most one result");
 
     SmallVector<Type, 1> resultType;
     if (results.size() == 1) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 928030a81dc02..8514bcb089891 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
   return %0 : tensor<2x3xf32>
 }
 
-// ----
+// -----
 
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,3 +1096,138 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }
 
+// -----
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+  // CHECK-LABEL: @singleton_batch_matmul_tensor
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-DAG:    %[[C2:.*]] = arith.constant 2
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+      outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+  return %1 : tensor<1x?x?xf32>
+}
+
+// -----
+
+func.func @singletone_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @singletone_batch_matmul_memref
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+      outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @singletone_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singletone_batch_matvec
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
+      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singletone_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singletone_batch_vecmat
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singletone_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singletone_batchmatmul_transpose_a
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @singletone_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singletone_batchmatmul_transpose_b
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+      outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+  return %1 : tensor<2x?x?xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+

>From 543b0d643506d12c658e2984d943873ed4c8b78b Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 16 Jun 2024 12:12:35 -0500
Subject: [PATCH 03/13] remove unecessary changes

---
 .../mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml        | 1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp                        | 2 --
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 41f90483c93b3..3f0aa33767a75 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -309,6 +309,7 @@ metadata: !LinalgOpMetadata
   cpp_class_name: ReciprocalOp
   doc: |-
     Applies reciprocal(x) elementwise.
+
     No numeric casting is performed on the input operand.
 structured_op: !LinalgStructuredOpConfig
   args:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4e47b6018c445..8df33a107c2cb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -43,7 +42,6 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
-#include <numeric>
 #include <optional>
 
 using namespace mlir;

>From 5732d87375c942306ddf5c7a6661b8123f423b1c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 18 Jun 2024 20:28:31 -0500
Subject: [PATCH 04/13] Move patterns to a populate function and implement test
 pass

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  14 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   7 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 115 ---------------
 .../Linalg/Transforms/DropUnitDims.cpp        |  98 +++++++++++++
 .../linalg/opdsl/ops/core_named_ops.py        |   5 -
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 137 +-----------------
 mlir/test/lib/Dialect/Linalg/CMakeLists.txt   |   1 +
 .../TestLinalgRankReduceContractionOps.cpp    |  68 +++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 9 files changed, 179 insertions(+), 268 deletions(-)
 create mode 100644 mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3f0aa33767a75..fad234a9dcae9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -516,7 +516,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: erf
-  cpp_class_name: ErfOp
+  cpp_class_name: erfOp
   doc: |-
     Applies erf(x) elementwise.
 
@@ -959,7 +959,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: powf
-  cpp_class_name: PowfOp
+  cpp_class_name: PowFOp
   doc: |-
     Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
 
@@ -1622,8 +1622,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -1694,8 +1692,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -1766,8 +1762,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -2146,8 +2140,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -2216,8 +2208,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..c49383c600a57 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Adds patterns that reduce the rank of named contraction ops that have
+/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
+/// `<corresponding linalg named op>`, `expand_shape` (if on tensors).  For example a
+/// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
+/// and a `linalg.matvec` with unit spatial dim in lhs will convert to a `linalg.dot`.
+void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8df33a107c2cb..b79afebfa8158 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -578,121 +578,6 @@ class RegionBuilderHelper {
 
 } // namespace
 
-//===----------------------------------------------------------------------===//
-// BatchMatmulOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-template <typename BatchOpTy, typename OpTy>
-struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
-  using OpRewritePattern<BatchOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
-                                PatternRewriter &rewriter) const override {
-
-    auto loc = batchMatmulOp.getLoc();
-    auto inputs = batchMatmulOp.getDpsInputs();
-    auto inits = batchMatmulOp.getDpsInits();
-    if (inputs.size() != 2 || inits.size() != 1)
-      return rewriter.notifyMatchFailure(batchMatmulOp,
-                                         "expected 2 inputs and 1 init");
-    auto lhs = inputs[0];
-    auto rhs = inputs[1];
-    auto init = inits[0];
-
-    auto lhsType = cast<ShapedType>(lhs.getType());
-    auto rhsType = cast<ShapedType>(rhs.getType());
-    auto initType = cast<ShapedType>(init.getType());
-    if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
-        initType.getShape()[0] != 1)
-      return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1");
-
-    auto results = batchMatmulOp.getResults();
-    assert(results.size() < 2 && "expected at most one result");
-
-    SmallVector<Type, 1> resultType;
-    if (results.size() == 1) {
-      auto oldResultType = cast<RankedTensorType>(results[0].getType());
-      resultType.push_back(
-          RankedTensorType::get(oldResultType.getShape().drop_front(1),
-                                oldResultType.getElementType()));
-    }
-
-    auto collapseSingletonDim = [&](Value val) -> Value {
-      SmallVector<ReassociationIndices> reassociation({{0, 1}});
-      auto valType = cast<ShapedType>(val.getType());
-      for (auto i = 2; i < valType.getRank(); i++)
-        reassociation.push_back({i});
-      if (isa<RankedTensorType>(valType)) {
-        RankedTensorType collapsedType = RankedTensorType::get(
-            valType.getShape().drop_front(1), valType.getElementType());
-        return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
-                                                        reassociation);
-      }
-      MemRefType collapsedType = MemRefType::get(
-          valType.getShape().drop_front(1), valType.getElementType());
-      return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
-                                                      reassociation);
-    };
-
-    auto collapsedLhs = collapseSingletonDim(lhs);
-    auto collapsedRhs = collapseSingletonDim(rhs);
-    auto collapsedInit = collapseSingletonDim(init);
-
-    auto collapsedOp = rewriter.create<OpTy>(
-        loc, resultType, ValueRange{collapsedLhs, collapsedRhs},
-        ValueRange{collapsedInit});
-    for (auto attr : batchMatmulOp->getAttrs()) {
-      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
-        continue;
-      collapsedOp->setAttr(attr.getName(), attr.getValue());
-    }
-
-    if (results.size() < 1) {
-      rewriter.replaceOp(batchMatmulOp, collapsedOp);
-    } else {
-      SmallVector<ReassociationIndices> reassociation({{0, 1}});
-      auto resultType = cast<ShapedType>(results[0].getType());
-      for (auto i = 2; i < resultType.getRank(); i++)
-        reassociation.push_back({i});
-      Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
-          loc, resultType, collapsedOp.getResultTensors()[0], reassociation);
-      rewriter.replaceOp(batchMatmulOp, expandedResult);
-    }
-
-    return success();
-  }
-};
-
-} // namespace
-
-void BatchMatmulOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                MLIRContext *context) {
-  results.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
-}
-
-void BatchMatmulTransposeAOp::getCanonicalizationPatterns(
-    RewritePatternSet &results, MLIRContext *context) {
-  results.add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
-      context);
-}
-
-void BatchMatmulTransposeBOp::getCanonicalizationPatterns(
-    RewritePatternSet &results, MLIRContext *context) {
-  results.add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
-      context);
-}
-
-void BatchMatvecOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                MLIRContext *context) {
-  results.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
-}
-
-void BatchVecmatOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                MLIRContext *context) {
-  results.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c0829397f1f85..9248710d5afc9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -812,6 +812,103 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern(
   patterns.add<MoveInitOperandsToInput>(patterns.getContext());
 }
 
+namespace {
+
+template <typename BatchOpTy, typename OpTy>
+struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
+  using OpRewritePattern<BatchOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = batchMatmulOp.getLoc();
+    auto inputs = batchMatmulOp.getDpsInputs();
+    auto inits = batchMatmulOp.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "expected 2 inputs and 1 init");
+    auto lhs = inputs[0];
+    auto rhs = inputs[1];
+    auto init = inits[0];
+
+    auto lhsType = cast<ShapedType>(lhs.getType());
+    auto rhsType = cast<ShapedType>(rhs.getType());
+    auto initType = cast<ShapedType>(init.getType());
+    if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
+        initType.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1");
+
+    auto results = batchMatmulOp.getResults();
+    assert(results.size() < 2 && "expected at most one result");
+
+    SmallVector<Type, 1> resultType;
+    if (results.size() == 1) {
+      auto oldResultType = cast<RankedTensorType>(results[0].getType());
+      resultType.push_back(
+          RankedTensorType::get(oldResultType.getShape().drop_front(1),
+                                oldResultType.getElementType()));
+    }
+
+    auto collapseSingletonDim = [&](Value val) -> Value {
+      SmallVector<ReassociationIndices> reassociation({{0, 1}});
+      auto valType = cast<ShapedType>(val.getType());
+      for (auto i = 2; i < valType.getRank(); i++)
+        reassociation.push_back({i});
+      if (isa<RankedTensorType>(valType)) {
+        RankedTensorType collapsedType = RankedTensorType::get(
+            valType.getShape().drop_front(1), valType.getElementType());
+        return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
+                                                        reassociation);
+      }
+      MemRefType collapsedType = MemRefType::get(
+          valType.getShape().drop_front(1), valType.getElementType());
+      return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
+                                                      reassociation);
+    };
+
+    auto collapsedLhs = collapseSingletonDim(lhs);
+    auto collapsedRhs = collapseSingletonDim(rhs);
+    auto collapsedInit = collapseSingletonDim(init);
+
+    auto collapsedOp = rewriter.create<OpTy>(
+        loc, resultType, ValueRange{collapsedLhs, collapsedRhs},
+        ValueRange{collapsedInit});
+    for (auto attr : batchMatmulOp->getAttrs()) {
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+        continue;
+      collapsedOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    if (results.size() < 1) {
+      rewriter.replaceOp(batchMatmulOp, collapsedOp);
+    } else {
+      SmallVector<ReassociationIndices> reassociation({{0, 1}});
+      auto resultType = cast<ShapedType>(results[0].getType());
+      for (auto i = 2; i < resultType.getRank(); i++)
+        reassociation.push_back({i});
+      Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
+          loc, resultType, collapsedOp.getResultTensors()[0], reassociation);
+      rewriter.replaceOp(batchMatmulOp, expandedResult);
+    }
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateContractionOpRankReducingPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
+  patterns
+      .add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+          context);
+  patterns
+      .add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+          context);
+  patterns.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
+  patterns.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
+}
+
 namespace {
 /// Pass that removes unit-extent dims within generic ops.
 struct LinalgFoldUnitExtentDimsPass
@@ -833,4 +930,5 @@ struct LinalgFoldUnitExtentDimsPass
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
+
 } // namespace
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index b4b36ba0bfe51..43410aaa6af1b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -518,7 +518,6 @@ def batch_matmul(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
-    defines(Canonicalizer)
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -538,7 +537,6 @@ def batch_matmul_transpose_a(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
-    defines(Canonicalizer)
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
@@ -558,7 +556,6 @@ def batch_matmul_transpose_b(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
-    defines(Canonicalizer)
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -645,7 +642,6 @@ def batch_matvec(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
-    defines(Canonicalizer)
     domain(D.b, D.m, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -664,7 +660,6 @@ def batch_vecmat(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
-    defines(Canonicalizer)
     domain(D.b, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 8514bcb089891..928030a81dc02 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
   return %0 : tensor<2x3xf32>
 }
 
-// -----
+// ----
 
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,138 +1096,3 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }
 
-// -----
-
-func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
-  // CHECK-LABEL: @singleton_batch_matmul_tensor
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
-  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
-  //  CHECK-DAG:    %[[C2:.*]] = arith.constant 2
-  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
-  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
-  //  CHECK-NEXT:   %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
-  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
-  //  CHECK-NEXT:   return %[[RES]]
-  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
-      outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
-  return %1 : tensor<1x?x?xf32>
-}
-
-// -----
-
-func.func @singletone_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
-  // CHECK-LABEL: @singletone_batch_matmul_memref
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
-  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:    linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
-  //  CHECK-NEXT:   return
-  linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
-      outs(%arg2 : memref<1x?x?xf32>)
-  return
-}
-
-// -----
-
-func.func @singletone_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
-  // CHECK-LABEL: @singletone_batch_matvec
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
-  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
-  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
-  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
-  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
-  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
-  //  CHECK-NEXT:   return %[[RES]]
-  %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
-      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
-  return %1 : tensor<1x?xf32>
-}
-
-// -----
-
-func.func @singletone_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
-  // CHECK-LABEL: @singletone_batch_vecmat
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
-  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
-  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
-  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
-  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
-  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
-  //  CHECK-NEXT:   return %[[RES]]
-  %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
-      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
-  return %1 : tensor<1x?xf32>
-}
-
-// -----
-
-func.func @singletone_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
-  // CHECK-LABEL: @singletone_batchmatmul_transpose_a
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
-  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:    linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
-  //  CHECK-NEXT:   return
-  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
-  return
-}
-
-// -----
-
-func.func @singletone_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
-  // CHECK-LABEL: @singletone_batchmatmul_transpose_b
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
-  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:    linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
-  //  CHECK-NEXT:   return
-  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
-  return
-}
-
-// -----
-
-func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
-  // CHECK-LABEL: @nonsingleton_batch_matmul
-  // CHECK-NOT:   collapse_shape
-  // CHECK:       linalg.batch_matmul
-  // CHECK-NOT:   expand_shape
-  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
-      outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
-  return %1 : tensor<2x?x?xf32>
-}
-
-// -----
-
-func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-  // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
-  // CHECK-NOT:   collapse_shape
-  // CHECK:       linalg.batch_matmul
-  // CHECK-NOT:   expand_shape
-  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-      outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-  return %1 : tensor<?x?x?xf32>
-}
-
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index b28f2b3564662..283e426b4e594 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgDropUnitDims.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
+  TestLinalgRankReduceContractionOps.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
 
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
new file mode 100644
index 0000000000000..5ca27be30a687
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
@@ -0,0 +1,68 @@
+//===- TestLinalgRankReduceContractionOps.cpp - Test Linalg rank reduce
+//contractions ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing rank reducing patterns for named
+// contraction ops with unit dims.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLinalgRankReduceContractionOps
+    : public PassWrapper<TestLinalgRankReduceContractionOps,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestLinalgRankReduceContractionOps)
+
+  TestLinalgRankReduceContractionOps() = default;
+  TestLinalgRankReduceContractionOps(
+      const TestLinalgRankReduceContractionOps &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect, tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-rank-reduce-contraction-ops";
+  }
+  StringRef getDescription() const final {
+    return "Test Linalg rank reduce contraction ops with unit dims";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    func::FuncOp funcOp = this->getOperation();
+
+    RewritePatternSet patterns(context);
+    linalg::populateContractionOpRankReducingPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                            std::move(patterns))))
+      return signalPassFailure();
+    return;
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgRankReduceContractionOps() {
+  PassRegistration<TestLinalgRankReduceContractionOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0e8b161d51345..d4ea7a9cae0d2 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@ void registerTestLinalgDecomposeOps();
 void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
+void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
@@ -235,6 +236,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgDropUnitDims();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgGreedyFusion();
+  mlir::test::registerTestLinalgRankReduceContractionOps();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();

>From 28078405788b799cf64bcf5a7a4059c0eb739875 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 10:24:39 -0500
Subject: [PATCH 05/13] refactor common logic into abstract base class

---
 .../Linalg/Transforms/DropUnitDims.cpp        | 230 +++++++++++++-----
 1 file changed, 166 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9248710d5afc9..07b0cdea40c92 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -814,10 +814,66 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern(
 
 namespace {
 
-template <typename BatchOpTy, typename OpTy>
-struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
-  using OpRewritePattern<BatchOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
+/// Reassociation merging the last two of `rank` dims; earlier dims map 1:1.
+static SmallVector<ReassociationIndices>
+getReassociationsForTrailingDims(int64_t rank) {
+  if (rank < 2) // nothing to merge; also avoids a negative vector size
+    return {};
+  SmallVector<ReassociationIndices> reassociation(rank - 1, {});
+  for (int64_t i = 0; i < rank - 2; i++)
+    reassociation[i] = {i};
+  reassociation[rank - 2] = {rank - 2, rank - 1};
+  return reassociation;
+}
+
+/// Reassociation merging the first two of `rank` dims; later dims map 1:1.
+static SmallVector<ReassociationIndices>
+getReassociationsForLeadingDims(int64_t rank) {
+  if (rank < 2) // nothing to merge; also avoids a negative vector size
+    return {};
+  SmallVector<ReassociationIndices> reassociation(rank - 1, {});
+  reassociation[0] = {0, 1};
+  for (int64_t i = 1; i < rank - 1; i++)
+    reassociation[i] = {i + 1}; // group i holds source dim i + 1
+  return reassociation;
+}
+
+static Value collapseLeadingSingletonDim(PatternRewriter &rewriter, Value val) {
+  auto valType = cast<ShapedType>(val.getType());
+  return collapseValue(
+      rewriter, val.getLoc(), val, valType.getShape().drop_front(1),
+      getReassociationsForLeadingDims(valType.getRank()),
+      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
+static Value collapseTrailingSingletonDim(PatternRewriter &rewriter,
+                                          Value val) {
+  auto valType = cast<ShapedType>(val.getType());
+  return collapseValue(
+      rewriter, val.getLoc(), val, valType.getShape().drop_back(1),
+      getReassociationsForTrailingDims(valType.getRank()),
+      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
+static Value expandLeadingSingletonDim(PatternRewriter &rewriter, Value val,
+                                       RankedTensorType expandedType) {
+  return rewriter.create<tensor::ExpandShapeOp>(
+      val.getLoc(), expandedType, val,
+      getReassociationsForLeadingDims(expandedType.getRank()));
+}
+
+static Value expandTrailingSingletonDim(PatternRewriter &rewriter, Value val,
+                                        RankedTensorType expandedType) {
+  return rewriter.create<tensor::ExpandShapeOp>(
+      val.getLoc(), expandedType, val,
+      getReassociationsForTrailingDims(expandedType.getRank()));
+}
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
+  using OpRewritePattern<FromOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromOpTy batchMatmulOp,
                                 PatternRewriter &rewriter) const override {
 
     auto loc = batchMatmulOp.getLoc();
@@ -830,47 +886,19 @@ struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
     auto rhs = inputs[1];
     auto init = inits[0];
 
-    auto lhsType = cast<ShapedType>(lhs.getType());
-    auto rhsType = cast<ShapedType>(rhs.getType());
-    auto initType = cast<ShapedType>(init.getType());
-    if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
-        initType.getShape()[0] != 1)
-      return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1");
-
-    auto results = batchMatmulOp.getResults();
-    assert(results.size() < 2 && "expected at most one result");
-
-    SmallVector<Type, 1> resultType;
-    if (results.size() == 1) {
-      auto oldResultType = cast<RankedTensorType>(results[0].getType());
-      resultType.push_back(
-          RankedTensorType::get(oldResultType.getShape().drop_front(1),
-                                oldResultType.getElementType()));
-    }
-
-    auto collapseSingletonDim = [&](Value val) -> Value {
-      SmallVector<ReassociationIndices> reassociation({{0, 1}});
-      auto valType = cast<ShapedType>(val.getType());
-      for (auto i = 2; i < valType.getRank(); i++)
-        reassociation.push_back({i});
-      if (isa<RankedTensorType>(valType)) {
-        RankedTensorType collapsedType = RankedTensorType::get(
-            valType.getShape().drop_front(1), valType.getElementType());
-        return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
-                                                        reassociation);
-      }
-      MemRefType collapsedType = MemRefType::get(
-          valType.getShape().drop_front(1), valType.getElementType());
-      return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
-                                                      reassociation);
-    };
-
-    auto collapsedLhs = collapseSingletonDim(lhs);
-    auto collapsedRhs = collapseSingletonDim(rhs);
-    auto collapsedInit = collapseSingletonDim(init);
-
-    auto collapsedOp = rewriter.create<OpTy>(
-        loc, resultType, ValueRange{collapsedLhs, collapsedRhs},
+    if (!checkTypes(lhs, rhs, init))
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "no reducable dims found");
+
+    auto collapsedOperands = collapseOperands(rewriter, lhs, rhs, init);
+    auto collapsedLhs = collapsedOperands[0];
+    auto collapsedRhs = collapsedOperands[1];
+    auto collapsedInit = collapsedOperands[2];
+    SmallVector<Type, 1> collapsedResultTy;
+    if (isa<RankedTensorType>(collapsedInit.getType()))
+      collapsedResultTy.push_back(collapsedInit.getType());
+    auto collapsedOp = rewriter.create<ToOpTy>(
+        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
         ValueRange{collapsedInit});
     for (auto attr : batchMatmulOp->getAttrs()) {
       if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
@@ -878,35 +906,109 @@ struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
       collapsedOp->setAttr(attr.getName(), attr.getValue());
     }
 
-    if (results.size() < 1) {
+    auto results = batchMatmulOp.getResults();
+    assert(results.size() < 2 && "expected at most one result");
+    if (results.size() < 1)
       rewriter.replaceOp(batchMatmulOp, collapsedOp);
-    } else {
-      SmallVector<ReassociationIndices> reassociation({{0, 1}});
-      auto resultType = cast<ShapedType>(results[0].getType());
-      for (auto i = 2; i < resultType.getRank(); i++)
-        reassociation.push_back({i});
-      Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
-          loc, resultType, collapsedOp.getResultTensors()[0], reassociation);
-      rewriter.replaceOp(batchMatmulOp, expandedResult);
-    }
+    else
+      rewriter.replaceOp(
+          batchMatmulOp,
+          expandResult(rewriter, collapsedOp.getResultTensors()[0],
+                       cast<RankedTensorType>(results[0].getType())));
 
     return success();
   }
+
+  virtual bool checkTypes(Value lhs, Value rhs, Value init) const = 0;
+  virtual SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter,
+                                                 Value lhs, Value rhs,
+                                                 Value init) const = 0;
+  virtual Value expandResult(PatternRewriter &rewriter, Value result,
+                             RankedTensorType expandedType) const = 0;
+};
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  bool checkTypes(Value lhs, Value rhs, Value init) const override {
+    auto lhsType = cast<ShapedType>(lhs.getType());
+    auto rhsType = cast<ShapedType>(rhs.getType());
+    auto initType = cast<ShapedType>(init.getType());
+    return lhsType.getShape()[0] == 1 && rhsType.getShape()[0] == 1 &&
+           initType.getShape()[0] == 1;
+  }
+
+  SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
+                                         Value rhs, Value init) const override {
+    auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
+    auto collapsedRhs = collapseLeadingSingletonDim(rewriter, rhs);
+    auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
+    return SmallVector<Value, 3>{collapsedLhs, collapsedRhs, collapsedInit};
+  }
+  Value expandResult(PatternRewriter &rewriter, Value result,
+                     RankedTensorType expandedType) const override {
+    return expandLeadingSingletonDim(rewriter, result, expandedType);
+  }
+};
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  static bool constexpr reduceLeading =
+      (std::is_same<FromOpTy, MatmulOp>::value &&
+       std::is_same<ToOpTy, VecmatOp>::value) ||
+      (std::is_same<FromOpTy, MatvecOp>::value &&
+       std::is_same<ToOpTy, DotOp>::value);
+
+  bool checkTypes(Value lhs, Value rhs, Value init) const override {
+    auto lhsType = cast<ShapedType>(lhs.getType());
+    auto rhsType = cast<ShapedType>(rhs.getType());
+    auto initType = cast<ShapedType>(init.getType());
+    if (reduceLeading)
+      return lhsType.getShape()[0] == 1 && initType.getShape()[0] == 1;
+    else
+      return rhsType.getShape().back() == 1 && initType.getShape().back() == 1;
+  }
+
+  SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
+                                         Value rhs, Value init) const override {
+    if (reduceLeading) {
+      auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
+      auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
+      return SmallVector<Value, 3>{collapsedLhs, rhs, collapsedInit};
+    } else {
+      auto collapsedRhs = collapseTrailingSingletonDim(rewriter, rhs);
+      auto collapsedInit = collapseTrailingSingletonDim(rewriter, init);
+      return SmallVector<Value, 3>{lhs, collapsedRhs, collapsedInit};
+    }
+  }
+  Value expandResult(PatternRewriter &rewriter, Value result,
+                     RankedTensorType expandedType) const override {
+    if (reduceLeading)
+      return expandLeadingSingletonDim(rewriter, result, expandedType);
+    else
+      return expandTrailingSingletonDim(rewriter, result, expandedType);
+  }
 };
+
 } // namespace
 
 void mlir::linalg::populateContractionOpRankReducingPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
-  patterns
-      .add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
-          context);
-  patterns
-      .add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
-          context);
-  patterns.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
-  patterns.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
+  patterns.add<RankReduceBatched<BatchMatmulOp, MatmulOp>>(context);
+  patterns.add<RankReduceBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+      context);
+  patterns.add<RankReduceBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+      context);
+  patterns.add<RankReduceBatched<BatchMatvecOp, MatvecOp>>(context);
+  patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
+  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
+  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
 }
 
 namespace {

>From 679192b56ac231fce54824d0189a52f39ea2a63b Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 17:07:46 -0500
Subject: [PATCH 06/13] add regression test

---
 .../Linalg/rank-reduce-contraction-ops.mlir   | 197 ++++++++++++++++++
 1 file changed, 197 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir

diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
new file mode 100644
index 0000000000000..279a1d52ae72b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -0,0 +1,197 @@
+//RUN: mlir-opt -test-linalg-rank-reduce-contraction-ops --canonicalize -split-input-file %s | FileCheck %s
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+  // CHECK-LABEL: @singleton_batch_matmul_tensor
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-DAG:    %[[C2:.*]] = arith.constant 2
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+      outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+  return %1 : tensor<1x?x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @singleton_batch_matmul_memref
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+      outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_matvec
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
+      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_vecmat
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_a
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_b
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @matmul_to_vecmat(%arg0: memref<1x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<1x?xf32>) {
+  // CHECK-LABEL: @matmul_to_vecmat
+  // CHECK: linalg.vecmat
+    linalg.matmul ins(%arg0, %arg1: memref<1x?xf32>, memref<?x?xf32>) outs(%arg2: memref<1x?xf32>)
+    return
+}
+
+// -----
+
+func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?xf32>, %arg2: memref<1x1x?xf32>) {
+  // CHECK-LABEL: @batch_matmul_to_vecmat
+  // CHECK: linalg.vecmat
+    linalg.batch_matmul ins(%arg0, %arg1: memref<1x1x?xf32>, memref<1x?x?xf32>) outs(%arg2: memref<1x1x?xf32>)
+    return
+}
+
+// -----
+
+func.func @matvec_to_dot(%arg0: memref<1x?xf32>, %arg1: memref<?xf32>, %arg2: memref<1xf32>) {
+  // CHECK-LABEL: @matvec_to_dot
+  // CHECK: linalg.dot
+    linalg.matvec ins(%arg0, %arg1: memref<1x?xf32>, memref<?xf32>) outs(%arg2: memref<1xf32>)
+    return
+}
+
+// -----
+
+func.func @vecmat_to_dot(%arg0: memref<?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<1xf32>) {
+  // CHECK-LABEL: @vecmat_to_dot
+  // CHECK: linalg.dot
+    linalg.vecmat ins(%arg0, %arg1: memref<?xf32>, memref<?x1xf32>) outs(%arg2: memref<1xf32>)
+    return
+}
+
+// -----
+
+func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-LABEL: @matvec_to_dot_tensor
+  // CHECK: linalg.dot
+    %0 = linalg.matvec ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?xf32>) outs(%arg2: tensor<1xf32>) -> tensor<1xf32>
+    return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK-LABEL: @matmul_to_matvec_tensor
+  // CHECK: linalg.matvec
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
+    return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @matmul_to_matvec
+  // CHECK: linalg.matvec
+    linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
+    return
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+      outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+  return %1 : tensor<2x?x?xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}

>From ce02e9d206cf718927854338636b6f357c62101c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 18:13:39 -0500
Subject: [PATCH 07/13] flesh out some tests

---
 .../Linalg/rank-reduce-contraction-ops.mlir   | 69 ++++++++++++-------
 1 file changed, 46 insertions(+), 23 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 279a1d52ae72b..79003670d2726 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -111,15 +111,51 @@ func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: me
 
 // -----
 
-func.func @matmul_to_vecmat(%arg0: memref<1x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<1x?xf32>) {
-  // CHECK-LABEL: @matmul_to_vecmat
-  // CHECK: linalg.vecmat
-    linalg.matmul ins(%arg0, %arg1: memref<1x?xf32>, memref<?x?xf32>) outs(%arg2: memref<1x?xf32>)
+func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK-LABEL: @matmul_to_matvec_tensor
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+  //  CHECK-DAG:    %[[C0:.*]] = arith.constant 0
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
+  //  CHECK-NEXT:   return %[[RES]]
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
+    return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @matmul_to_matvec
+  // CHECK: linalg.matvec
+    linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
     return
 }
 
 // -----
 
+func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @matmul_to_vecmat
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[RESULT:.*]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   return %[[RES]]
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
+    return %0 : tensor<1x?xf32>
+}
+
+// -----
+
 func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?xf32>, %arg2: memref<1x1x?xf32>) {
   // CHECK-LABEL: @batch_matmul_to_vecmat
   // CHECK: linalg.vecmat
@@ -131,7 +167,12 @@ func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?x
 
 func.func @matvec_to_dot(%arg0: memref<1x?xf32>, %arg1: memref<?xf32>, %arg2: memref<1xf32>) {
   // CHECK-LABEL: @matvec_to_dot
-  // CHECK: linalg.dot
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] []
+  //  CHECK-NEXT:   linalg.dot ins(%[[COLLAPSED_LHS]], %[[RHS]] : memref<?xf32>, memref<?xf32>) outs(%[[COLLAPSED_INIT]] : memref<f32>)
     linalg.matvec ins(%arg0, %arg1: memref<1x?xf32>, memref<?xf32>) outs(%arg2: memref<1xf32>)
     return
 }
@@ -156,24 +197,6 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %a
 
 // -----
 
-func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
-  // CHECK-LABEL: @matmul_to_matvec_tensor
-  // CHECK: linalg.matvec
-    %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
-    return %0 : tensor<?x1xf32>
-}
-
-// -----
-
-func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
-  // CHECK-LABEL: @matmul_to_matvec
-  // CHECK: linalg.matvec
-    linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
-    return
-}
-
-// -----
-
 func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
   // CHECK-LABEL: @nonsingleton_batch_matmul
   // CHECK-NOT:   collapse_shape

>From 9b98efdf911f806f3ae15f2f9e74d99af1d8775e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 20:13:19 -0500
Subject: [PATCH 08/13] support transpose matmul conversion

---
 .../Linalg/Transforms/DropUnitDims.cpp        | 119 ++++++++++--------
 .../Linalg/rank-reduce-contraction-ops.mlir   |  18 +++
 2 files changed, 84 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 07b0cdea40c92..d9230c6127e00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -812,6 +812,30 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern(
   patterns.add<MoveInitOperandsToInput>(patterns.getContext());
 }
 
+namespace {
+/// Pass that removes unit-extent dims within generic ops.
+struct LinalgFoldUnitExtentDimsPass
+    : public impl::LinalgFoldUnitExtentDimsPassBase<
+          LinalgFoldUnitExtentDimsPass> {
+  using impl::LinalgFoldUnitExtentDimsPassBase<
+      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+    RewritePatternSet patterns(context);
+    ControlDropUnitDims options;
+    if (useRankReducingSlices) {
+      options.rankReductionStrategy = linalg::ControlDropUnitDims::
+          RankReductionStrategy::ExtractInsertSlice;
+    }
+    linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
+    populateMoveInitOperandsToInputPattern(patterns);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+
+} // namespace
+
 namespace {
 
 static SmallVector<ReassociationIndices>
@@ -855,20 +879,6 @@ static Value collapseTrailingSingletonDim(PatternRewriter &rewriter,
       ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
 }
 
-static Value expandLeadingSingletonDim(PatternRewriter &rewriter, Value val,
-                                       RankedTensorType expandedType) {
-  return rewriter.create<tensor::ExpandShapeOp>(
-      val.getLoc(), expandedType, val,
-      getReassociationsForLeadingDims(expandedType.getRank()));
-}
-
-static Value expandTrailingSingletonDim(PatternRewriter &rewriter, Value val,
-                                        RankedTensorType expandedType) {
-  return rewriter.create<tensor::ExpandShapeOp>(
-      val.getLoc(), expandedType, val,
-      getReassociationsForTrailingDims(expandedType.getRank()));
-}
-
 template <typename FromOpTy, typename ToOpTy>
 struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
   using OpRewritePattern<FromOpTy>::OpRewritePattern;
@@ -948,7 +958,9 @@ struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
   }
   Value expandResult(PatternRewriter &rewriter, Value result,
                      RankedTensorType expandedType) const override {
-    return expandLeadingSingletonDim(rewriter, result, expandedType);
+  return rewriter.create<tensor::ExpandShapeOp>(
+      result.getLoc(), expandedType, result,
+      getReassociationsForLeadingDims(expandedType.getRank()));
   }
 };
 
@@ -956,9 +968,15 @@ template <typename FromOpTy, typename ToOpTy>
 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
 
-  static bool constexpr reduceLeading =
+  static bool constexpr isTranspose =
+      std::is_same<FromOpTy, MatmulTransposeAOp>::value ||
+      std::is_same<FromOpTy, MatmulTransposeBOp>::value;
+
+  static bool constexpr reduceLeft =
       (std::is_same<FromOpTy, MatmulOp>::value &&
        std::is_same<ToOpTy, VecmatOp>::value) ||
+      (std::is_same<FromOpTy, MatmulTransposeAOp>::value &&
+       std::is_same<ToOpTy, VecmatOp>::value) ||
       (std::is_same<FromOpTy, MatvecOp>::value &&
        std::is_same<ToOpTy, DotOp>::value);
 
@@ -966,30 +984,47 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
     auto lhsType = cast<ShapedType>(lhs.getType());
     auto rhsType = cast<ShapedType>(rhs.getType());
     auto initType = cast<ShapedType>(init.getType());
-    if (reduceLeading)
-      return lhsType.getShape()[0] == 1 && initType.getShape()[0] == 1;
+    int constexpr offset = (int)isTranspose;
+    if (reduceLeft)
+      return lhsType.getShape().begin()[offset] == 1 &&
+             initType.getShape().begin()[offset] == 1;
     else
-      return rhsType.getShape().back() == 1 && initType.getShape().back() == 1;
+      return rhsType.getShape().rbegin()[offset] == 1 &&
+             initType.getShape().rbegin()[offset] == 1;
   }
 
   SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
                                          Value rhs, Value init) const override {
-    if (reduceLeading) {
-      auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
-      auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
-      return SmallVector<Value, 3>{collapsedLhs, rhs, collapsedInit};
+    if (reduceLeft) {
+      if (isTranspose) {
+        lhs = collapseTrailingSingletonDim(rewriter, lhs);
+        init = collapseTrailingSingletonDim(rewriter, init);
+      } else {
+        lhs = collapseLeadingSingletonDim(rewriter, lhs);
+        init = collapseLeadingSingletonDim(rewriter, init);
+      }
     } else {
-      auto collapsedRhs = collapseTrailingSingletonDim(rewriter, rhs);
-      auto collapsedInit = collapseTrailingSingletonDim(rewriter, init);
-      return SmallVector<Value, 3>{lhs, collapsedRhs, collapsedInit};
+      if (isTranspose) {
+        rhs = collapseLeadingSingletonDim(rewriter, rhs);
+        init = collapseLeadingSingletonDim(rewriter, init);
+      } else {
+        rhs = collapseTrailingSingletonDim(rewriter, rhs);
+        init = collapseTrailingSingletonDim(rewriter, init);
+      }
     }
+    return SmallVector<Value, 3>{lhs, rhs, init};
   }
+
   Value expandResult(PatternRewriter &rewriter, Value result,
                      RankedTensorType expandedType) const override {
-    if (reduceLeading)
-      return expandLeadingSingletonDim(rewriter, result, expandedType);
+    if (reduceLeft)
+      return rewriter.create<tensor::ExpandShapeOp>(
+          result.getLoc(), expandedType, result,
+          getReassociationsForLeadingDims(expandedType.getRank()));
     else
-      return expandTrailingSingletonDim(rewriter, result, expandedType);
+      return rewriter.create<tensor::ExpandShapeOp>(
+          result.getLoc(), expandedType, result,
+          getReassociationsForTrailingDims(expandedType.getRank()));
   }
 };
 
@@ -1007,30 +1042,8 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
   patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
   patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
   patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
   patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
   patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
 }
-
-namespace {
-/// Pass that removes unit-extent dims within generic ops.
-struct LinalgFoldUnitExtentDimsPass
-    : public impl::LinalgFoldUnitExtentDimsPassBase<
-          LinalgFoldUnitExtentDimsPass> {
-  using impl::LinalgFoldUnitExtentDimsPassBase<
-      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    MLIRContext *context = op->getContext();
-    RewritePatternSet patterns(context);
-    ControlDropUnitDims options;
-    if (useRankReducingSlices) {
-      options.rankReductionStrategy = linalg::ControlDropUnitDims::
-          RankReductionStrategy::ExtractInsertSlice;
-    }
-    linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
-    populateMoveInitOperandsToInputPattern(patterns);
-    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
-  }
-};
-
-} // namespace
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 79003670d2726..0548f3f860a89 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -197,6 +197,24 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %a
 
 // -----
 
+func.func @matmul_transpose_a_to_vecmat(%arg0: memref<?x1xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @matmul_transpose_a_to_vecmat
+  // CHECK: linalg.vecmat
+    linalg.matmul_transpose_a ins(%arg0, %arg1: memref<?x1xf32>, memref<?x?xf32>) outs(%arg2: memref<?x1xf32>)
+    return
+}
+
+// -----
+
+func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) {
+  // CHECK-LABEL: @matmul_transpose_b_to_matvec
+  // CHECK: linalg.matvec
+    linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<1x?xf32>)
+    return
+}
+
+// -----
+
 func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
   // CHECK-LABEL: @nonsingleton_batch_matmul
   // CHECK-NOT:   collapse_shape

>From cbf8eddc377b7970cd12846a8d95ea740b5e3bf1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 20:28:48 -0500
Subject: [PATCH 09/13] const conditionals

---
 mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index d9230c6127e00..327c46b965c87 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -985,7 +985,7 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
     auto rhsType = cast<ShapedType>(rhs.getType());
     auto initType = cast<ShapedType>(init.getType());
     int constexpr offset = (int)isTranspose;
-    if (reduceLeft)
+    if constexpr (reduceLeft)
       return lhsType.getShape().begin()[offset] == 1 &&
              initType.getShape().begin()[offset] == 1;
     else
@@ -995,8 +995,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
 
   SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
                                          Value rhs, Value init) const override {
-    if (reduceLeft) {
-      if (isTranspose) {
+    if constexpr (reduceLeft) {
+      if constexpr (isTranspose) {
         lhs = collapseTrailingSingletonDim(rewriter, lhs);
         init = collapseTrailingSingletonDim(rewriter, init);
       } else {
@@ -1004,7 +1004,7 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
         init = collapseLeadingSingletonDim(rewriter, init);
       }
     } else {
-      if (isTranspose) {
+      if constexpr (isTranspose) {
         rhs = collapseLeadingSingletonDim(rewriter, rhs);
         init = collapseLeadingSingletonDim(rewriter, init);
       } else {
@@ -1017,7 +1017,7 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
 
   Value expandResult(PatternRewriter &rewriter, Value result,
                      RankedTensorType expandedType) const override {
-    if (reduceLeft)
+    if constexpr (reduceLeft)
       return rewriter.create<tensor::ExpandShapeOp>(
           result.getLoc(), expandedType, result,
           getReassociationsForLeadingDims(expandedType.getRank()));

>From 29ac128442379767bca2583409f58bd605bb06d1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 20 Jun 2024 14:19:38 -0500
Subject: [PATCH 10/13] refactor and add more patterns

---
 .../Linalg/Transforms/DropUnitDims.cpp        | 143 +++++++++++-------
 1 file changed, 85 insertions(+), 58 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 327c46b965c87..26a85e8225b24 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -839,43 +839,31 @@ struct LinalgFoldUnitExtentDimsPass
 namespace {
 
 static SmallVector<ReassociationIndices>
-getReassociationsForTrailingDims(int64_t rank) {
-  SmallVector<ReassociationIndices> reassociation(rank - 1, {});
-  if (rank > 1) {
-    reassociation[rank - 2] =
-        (rank == 1) ? ReassociationIndices{0} : ReassociationIndices{0, 1};
-    for (int64_t i = 0; i < rank - 2; i++)
-      reassociation[i] = {i};
-  }
-  return reassociation;
-}
-
-static SmallVector<ReassociationIndices>
-getReassociationsForLeadingDims(int64_t rank) {
-  SmallVector<ReassociationIndices> reassociation(rank - 1, {});
-  if (rank > 1) {
-    reassociation[0] =
-        (rank == 1) ? ReassociationIndices{0} : ReassociationIndices{0, 1};
-    for (int64_t i = 1; i < rank - 1; i++)
-      reassociation[i] = {i + rank - 2};
+getReassociationForReshapeAtDim(int64_t rank, int64_t pos,
+                                bool fromRight = false) {
+  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
+  if (rank > 2) {
+    int64_t offsetPos = pos - (int64_t)fromRight;
+    for (int64_t i = 0; i < rank - 1; i++) {
+      if (i == offsetPos)
+        reassociation[i] = ReassociationIndices{i, i + 1};
+      else if (i < offsetPos)
+        reassociation[i] = ReassociationIndices{i};
+      else
+        reassociation[i] = ReassociationIndices{i + offsetPos + 1};
+    }
   }
   return reassociation;
 }
 
-static Value collapseLeadingSingletonDim(PatternRewriter &rewriter, Value val) {
+static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
+                                    int64_t pos, bool fromRight = false) {
   auto valType = cast<ShapedType>(val.getType());
+  SmallVector<int64_t> collapsedShape(valType.getShape());
+  collapsedShape.erase(collapsedShape.begin() + pos);
   return collapseValue(
-      rewriter, val.getLoc(), val, valType.getShape().drop_front(1),
-      getReassociationsForLeadingDims(valType.getRank()),
-      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
-}
-
-static Value collapseTrailingSingletonDim(PatternRewriter &rewriter,
-                                          Value val) {
-  auto valType = cast<ShapedType>(val.getType());
-  return collapseValue(
-      rewriter, val.getLoc(), val, valType.getShape().drop_back(1),
-      getReassociationsForTrailingDims(valType.getRank()),
+      rewriter, val.getLoc(), val, collapsedShape,
+      getReassociationForReshapeAtDim(valType.getRank(), pos, fromRight),
       ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
 }
 
@@ -951,16 +939,16 @@ struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
 
   SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
                                          Value rhs, Value init) const override {
-    auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
-    auto collapsedRhs = collapseLeadingSingletonDim(rewriter, rhs);
-    auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
+    auto collapsedLhs = collapseSingletonDimAt(rewriter, lhs, 0);
+    auto collapsedRhs = collapseSingletonDimAt(rewriter, rhs, 0);
+    auto collapsedInit = collapseSingletonDimAt(rewriter, init, 0);
     return SmallVector<Value, 3>{collapsedLhs, collapsedRhs, collapsedInit};
   }
   Value expandResult(PatternRewriter &rewriter, Value result,
                      RankedTensorType expandedType) const override {
-  return rewriter.create<tensor::ExpandShapeOp>(
-      result.getLoc(), expandedType, result,
-      getReassociationsForLeadingDims(expandedType.getRank()));
+    return rewriter.create<tensor::ExpandShapeOp>(
+        result.getLoc(), expandedType, result,
+        getReassociationForReshapeAtDim(expandedType.getRank(), 0));
   }
 };
 
@@ -968,11 +956,32 @@ template <typename FromOpTy, typename ToOpTy>
 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
 
+  static bool constexpr isBatched =
+      std::is_same<FromOpTy, BatchMatmulOp>::value ||
+      std::is_same<FromOpTy, BatchMatvecOp>::value ||
+      std::is_same<FromOpTy, BatchVecmatOp>::value ||
+      std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
+      std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value;
+
+  static bool constexpr isLHSTransposed =
+      std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
+      std::is_same<FromOpTy, MatmulTransposeAOp>::value;
+
+  static bool constexpr isRHSTransposed =
+      std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value ||
+      std::is_same<FromOpTy, MatmulTransposeBOp>::value;
+
   static bool constexpr isTranspose =
+      std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
+      std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value ||
       std::is_same<FromOpTy, MatmulTransposeAOp>::value ||
       std::is_same<FromOpTy, MatmulTransposeBOp>::value;
 
   static bool constexpr reduceLeft =
+      (std::is_same<FromOpTy, BatchMatmulOp>::value &&
+       std::is_same<ToOpTy, BatchVecmatOp>::value) ||
+      (std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value &&
+       std::is_same<ToOpTy, BatchVecmatOp>::value) ||
       (std::is_same<FromOpTy, MatmulOp>::value &&
        std::is_same<ToOpTy, VecmatOp>::value) ||
       (std::is_same<FromOpTy, MatmulTransposeAOp>::value &&
@@ -980,37 +989,41 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
       (std::is_same<FromOpTy, MatvecOp>::value &&
        std::is_same<ToOpTy, DotOp>::value);
 
+  static int constexpr lhsTransposeOffset = (int)isLHSTransposed;
+  static int constexpr rhsTransposeOffset = (int)isRHSTransposed;
+  static int constexpr batchOffset = (int)isBatched;
+
   bool checkTypes(Value lhs, Value rhs, Value init) const override {
     auto lhsType = cast<ShapedType>(lhs.getType());
     auto rhsType = cast<ShapedType>(rhs.getType());
     auto initType = cast<ShapedType>(init.getType());
-    int constexpr offset = (int)isTranspose;
     if constexpr (reduceLeft)
-      return lhsType.getShape().begin()[offset] == 1 &&
-             initType.getShape().begin()[offset] == 1;
+      return lhsType.getShape().begin()[lhsTransposeOffset + batchOffset] ==
+                 1 &&
+             initType.getShape().begin()[lhsTransposeOffset + batchOffset] == 1;
     else
-      return rhsType.getShape().rbegin()[offset] == 1 &&
-             initType.getShape().rbegin()[offset] == 1;
+      return rhsType.getShape().rbegin()[rhsTransposeOffset] == 1 &&
+             initType.getShape().rbegin()[rhsTransposeOffset] == 1;
   }
 
   SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
                                          Value rhs, Value init) const override {
+
     if constexpr (reduceLeft) {
-      if constexpr (isTranspose) {
-        lhs = collapseTrailingSingletonDim(rewriter, lhs);
-        init = collapseTrailingSingletonDim(rewriter, init);
-      } else {
-        lhs = collapseLeadingSingletonDim(rewriter, lhs);
-        init = collapseLeadingSingletonDim(rewriter, init);
-      }
+      lhs = collapseSingletonDimAt(rewriter, lhs,
+                                   lhsTransposeOffset + batchOffset,
+                                   /*fromRight=*/isLHSTransposed);
+      init = collapseSingletonDimAt(rewriter, init,
+                                    lhsTransposeOffset + batchOffset,
+                                    /*fromRight*/ isLHSTransposed);
     } else {
-      if constexpr (isTranspose) {
-        rhs = collapseLeadingSingletonDim(rewriter, rhs);
-        init = collapseLeadingSingletonDim(rewriter, init);
-      } else {
-        rhs = collapseTrailingSingletonDim(rewriter, rhs);
-        init = collapseTrailingSingletonDim(rewriter, init);
-      }
+      auto rhsRank = cast<ShapedType>(rhs.getType()).getRank();
+      auto initRank = cast<ShapedType>(init.getType()).getRank();
+      rhs = collapseSingletonDimAt(
+          rewriter, rhs, rhsRank - rhsTransposeOffset - 1, /*fromRight=*/true);
+      init = collapseSingletonDimAt(rewriter, init,
+                                    initRank - rhsTransposeOffset - 1,
+                                    /*fromRight=*/true);
     }
     return SmallVector<Value, 3>{lhs, rhs, init};
   }
@@ -1020,11 +1033,13 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
     if constexpr (reduceLeft)
       return rewriter.create<tensor::ExpandShapeOp>(
           result.getLoc(), expandedType, result,
-          getReassociationsForLeadingDims(expandedType.getRank()));
+          getReassociationForReshapeAtDim(expandedType.getRank(), 0));
     else
       return rewriter.create<tensor::ExpandShapeOp>(
           result.getLoc(), expandedType, result,
-          getReassociationsForTrailingDims(expandedType.getRank()));
+          getReassociationForReshapeAtDim(expandedType.getRank(),
+                                          expandedType.getRank() - 1,
+                                          /*fromRight=*/true));
   }
 };
 
@@ -1033,6 +1048,7 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
 void mlir::linalg::populateContractionOpRankReducingPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
+  // Unbatching patterns for unit batch size
   patterns.add<RankReduceBatched<BatchMatmulOp, MatmulOp>>(context);
   patterns.add<RankReduceBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
       context);
@@ -1040,10 +1056,21 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
       context);
   patterns.add<RankReduceBatched<BatchMatvecOp, MatvecOp>>(context);
   patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
+
+  // Non-batch rank 1 reducing patterns
   patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
   patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
   patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
   patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
+  // Batch rank 1 reducing patterns
+  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
+  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
+  patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
+      context);
+  patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
+      context);
+
+  // Non-batch rank 0 reducing patterns
   patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
   patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
 }

>From 7851ae12b78d9b4fbb947de3a6c9fa1654fe0ab6 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 20 Jun 2024 19:13:36 -0500
Subject: [PATCH 11/13] more refactor

---
 .../Linalg/Transforms/DropUnitDims.cpp        | 229 +++++++++---------
 .../Linalg/rank-reduce-contraction-ops.mlir   |  55 ++---
 2 files changed, 141 insertions(+), 143 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 26a85e8225b24..771b40d6a2001 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -839,31 +839,32 @@ struct LinalgFoldUnitExtentDimsPass
 namespace {
 
 static SmallVector<ReassociationIndices>
-getReassociationForReshapeAtDim(int64_t rank, int64_t pos,
-                                bool fromRight = false) {
+getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
   SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
+  auto lastDim = pos == rank - 1;
   if (rank > 2) {
-    int64_t offsetPos = pos - (int64_t)fromRight;
     for (int64_t i = 0; i < rank - 1; i++) {
-      if (i == offsetPos)
+      if (i == pos || (lastDim && i == pos - 1))
         reassociation[i] = ReassociationIndices{i, i + 1};
-      else if (i < offsetPos)
+      else if (i < pos)
         reassociation[i] = ReassociationIndices{i};
       else
-        reassociation[i] = ReassociationIndices{i + offsetPos + 1};
+        reassociation[i] = ReassociationIndices{i + 1};
     }
   }
   return reassociation;
 }
 
 static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
-                                    int64_t pos, bool fromRight = false) {
+                                    int64_t pos) {
+  if (pos < 0)
+    return val;
   auto valType = cast<ShapedType>(val.getType());
   SmallVector<int64_t> collapsedShape(valType.getShape());
   collapsedShape.erase(collapsedShape.begin() + pos);
   return collapseValue(
       rewriter, val.getLoc(), val, collapsedShape,
-      getReassociationForReshapeAtDim(valType.getRank(), pos, fromRight),
+      getReassociationForReshapeAtDim(valType.getRank(), pos),
       ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
 }
 
@@ -871,24 +872,52 @@ template <typename FromOpTy, typename ToOpTy>
 struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
   using OpRewritePattern<FromOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(FromOpTy batchMatmulOp,
+  SmallVector<Value, 3>
+  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
+                   ArrayRef<int64_t> operandCollapseDims) const {
+    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
+           "expected 3 operands and dims");
+    return llvm::to_vector(llvm::map_range(
+        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
+          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
+                                        std::get<1>(pair));
+        }));
+  }
+
+  Value expandResult(PatternRewriter &rewriter, Value result,
+                     RankedTensorType expandedType, int64_t dim) const {
+    return rewriter.create<tensor::ExpandShapeOp>(
+        result.getLoc(), expandedType, result,
+        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
+  }
+
+  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                 PatternRewriter &rewriter) const override {
 
-    auto loc = batchMatmulOp.getLoc();
-    auto inputs = batchMatmulOp.getDpsInputs();
-    auto inits = batchMatmulOp.getDpsInits();
+    auto loc = contractionOp.getLoc();
+    auto inputs = contractionOp.getDpsInputs();
+    auto inits = contractionOp.getDpsInits();
     if (inputs.size() != 2 || inits.size() != 1)
-      return rewriter.notifyMatchFailure(batchMatmulOp,
+      return rewriter.notifyMatchFailure(contractionOp,
                                          "expected 2 inputs and 1 init");
     auto lhs = inputs[0];
     auto rhs = inputs[1];
     auto init = inits[0];
+    SmallVector<Value> operands{lhs, rhs, init};
 
-    if (!checkTypes(lhs, rhs, init))
-      return rewriter.notifyMatchFailure(batchMatmulOp,
+    auto maybeContractionDims = inferContractionDims(contractionOp);
+    if (failed(maybeContractionDims))
+      return rewriter.notifyMatchFailure(contractionOp,
+                                         "could not infer contraction dims");
+
+    auto contractionDims = maybeContractionDims.value();
+    SmallVector<int64_t> operandUnitDims;
+    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
+      return rewriter.notifyMatchFailure(contractionOp,
                                          "no reducable dims found");
 
-    auto collapsedOperands = collapseOperands(rewriter, lhs, rhs, init);
+    auto collapsedOperands =
+        collapseOperands(rewriter, operands, operandUnitDims);
     auto collapsedLhs = collapsedOperands[0];
     auto collapsedRhs = collapsedOperands[1];
     auto collapsedInit = collapsedOperands[2];
@@ -898,57 +927,63 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
     auto collapsedOp = rewriter.create<ToOpTy>(
         loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
         ValueRange{collapsedInit});
-    for (auto attr : batchMatmulOp->getAttrs()) {
+    for (auto attr : contractionOp->getAttrs()) {
       if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
         continue;
       collapsedOp->setAttr(attr.getName(), attr.getValue());
     }
 
-    auto results = batchMatmulOp.getResults();
+    auto results = contractionOp.getResults();
     assert(results.size() < 2 && "expected at most one result");
     if (results.size() < 1)
-      rewriter.replaceOp(batchMatmulOp, collapsedOp);
+      rewriter.replaceOp(contractionOp, collapsedOp);
     else
       rewriter.replaceOp(
-          batchMatmulOp,
+          contractionOp,
           expandResult(rewriter, collapsedOp.getResultTensors()[0],
-                       cast<RankedTensorType>(results[0].getType())));
+                       cast<RankedTensorType>(results[0].getType()),
+                       operandUnitDims[2]));
 
     return success();
   }
 
-  virtual bool checkTypes(Value lhs, Value rhs, Value init) const = 0;
-  virtual SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter,
-                                                 Value lhs, Value rhs,
-                                                 Value init) const = 0;
-  virtual Value expandResult(PatternRewriter &rewriter, Value result,
-                             RankedTensorType expandedType) const = 0;
+  virtual LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDindices) const = 0;
 };
 
 template <typename FromOpTy, typename ToOpTy>
 struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
 
-  bool checkTypes(Value lhs, Value rhs, Value init) const override {
-    auto lhsType = cast<ShapedType>(lhs.getType());
-    auto rhsType = cast<ShapedType>(rhs.getType());
-    auto initType = cast<ShapedType>(init.getType());
-    return lhsType.getShape()[0] == 1 && rhsType.getShape()[0] == 1 &&
-           initType.getShape()[0] == 1;
-  }
+  LogicalResult getOperandUnitDims(
+      LinalgOp op,
+      SmallVectorImpl<int64_t> &operandUnitDindices) const override {
+    auto inputs = op.getDpsInputs();
+    auto inits = op.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return failure();
 
-  SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
-                                         Value rhs, Value init) const override {
-    auto collapsedLhs = collapseSingletonDimAt(rewriter, lhs, 0);
-    auto collapsedRhs = collapseSingletonDimAt(rewriter, rhs, 0);
-    auto collapsedInit = collapseSingletonDimAt(rewriter, init, 0);
-    return SmallVector<Value, 3>{collapsedLhs, collapsedRhs, collapsedInit};
-  }
-  Value expandResult(PatternRewriter &rewriter, Value result,
-                     RankedTensorType expandedType) const override {
-    return rewriter.create<tensor::ExpandShapeOp>(
-        result.getLoc(), expandedType, result,
-        getReassociationForReshapeAtDim(expandedType.getRank(), 0));
+    auto maybeContractionDims = inferContractionDims(op);
+    if (failed(maybeContractionDims))
+      return failure();
+    auto contractionDims = maybeContractionDims.value();
+
+    if (contractionDims.batch.size() != 1)
+      return failure();
+    auto batchDim = contractionDims.batch[0];
+    SmallVector<std::pair<Value, unsigned>, 2> bOperands;
+    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
+    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
+          return cast<ShapedType>(std::get<0>(pair).getType())
+                     .getShape()[std::get<1>(pair)] != 1;
+        }))
+      return failure();
+
+    operandUnitDindices = SmallVector<int64_t>{std::get<1>(bOperands[0]),
+                                               std::get<1>(bOperands[1]),
+                                               std::get<1>(bOperands[2])};
+    return success();
   }
 };
 
@@ -956,27 +991,6 @@ template <typename FromOpTy, typename ToOpTy>
 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
 
-  static bool constexpr isBatched =
-      std::is_same<FromOpTy, BatchMatmulOp>::value ||
-      std::is_same<FromOpTy, BatchMatvecOp>::value ||
-      std::is_same<FromOpTy, BatchVecmatOp>::value ||
-      std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
-      std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value;
-
-  static bool constexpr isLHSTransposed =
-      std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
-      std::is_same<FromOpTy, MatmulTransposeAOp>::value;
-
-  static bool constexpr isRHSTransposed =
-      std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value ||
-      std::is_same<FromOpTy, MatmulTransposeBOp>::value;
-
-  static bool constexpr isTranspose =
-      std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
-      std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value ||
-      std::is_same<FromOpTy, MatmulTransposeAOp>::value ||
-      std::is_same<FromOpTy, MatmulTransposeBOp>::value;
-
   static bool constexpr reduceLeft =
       (std::is_same<FromOpTy, BatchMatmulOp>::value &&
        std::is_same<ToOpTy, BatchVecmatOp>::value) ||
@@ -989,57 +1003,44 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
       (std::is_same<FromOpTy, MatvecOp>::value &&
        std::is_same<ToOpTy, DotOp>::value);
 
-  static int constexpr lhsTransposeOffset = (int)isLHSTransposed;
-  static int constexpr rhsTransposeOffset = (int)isRHSTransposed;
-  static int constexpr batchOffset = (int)isBatched;
-
-  bool checkTypes(Value lhs, Value rhs, Value init) const override {
-    auto lhsType = cast<ShapedType>(lhs.getType());
-    auto rhsType = cast<ShapedType>(rhs.getType());
-    auto initType = cast<ShapedType>(init.getType());
-    if constexpr (reduceLeft)
-      return lhsType.getShape().begin()[lhsTransposeOffset + batchOffset] ==
-                 1 &&
-             initType.getShape().begin()[lhsTransposeOffset + batchOffset] == 1;
-    else
-      return rhsType.getShape().rbegin()[rhsTransposeOffset] == 1 &&
-             initType.getShape().rbegin()[rhsTransposeOffset] == 1;
-  }
-
-  SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
-                                         Value rhs, Value init) const override {
+  LogicalResult getOperandUnitDims(
+      LinalgOp op,
+      SmallVectorImpl<int64_t> &operandUnitDindices) const override {
+    auto maybeContractionDims = inferContractionDims(op);
+    if (failed(maybeContractionDims))
+      return failure();
+    auto contractionDims = maybeContractionDims.value();
 
     if constexpr (reduceLeft) {
-      lhs = collapseSingletonDimAt(rewriter, lhs,
-                                   lhsTransposeOffset + batchOffset,
-                                   /*fromRight=*/isLHSTransposed);
-      init = collapseSingletonDimAt(rewriter, init,
-                                    lhsTransposeOffset + batchOffset,
-                                    /*fromRight*/ isLHSTransposed);
+      auto m = contractionDims.m[0];
+      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
+      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
+      if (mOperands.size() != 2)
+        return failure();
+      if (llvm::all_of(mOperands, [](auto pair) {
+            return cast<ShapedType>(std::get<0>(pair).getType())
+                       .getShape()[std::get<1>(pair)] == 1;
+          })) {
+        operandUnitDindices = SmallVector<int64_t>{
+            std::get<1>(mOperands[0]), -1, std::get<1>(mOperands[1])};
+        return success();
+      }
     } else {
-      auto rhsRank = cast<ShapedType>(rhs.getType()).getRank();
-      auto initRank = cast<ShapedType>(init.getType()).getRank();
-      rhs = collapseSingletonDimAt(
-          rewriter, rhs, rhsRank - rhsTransposeOffset - 1, /*fromRight=*/true);
-      init = collapseSingletonDimAt(rewriter, init,
-                                    initRank - rhsTransposeOffset - 1,
-                                    /*fromRight=*/true);
+      auto n = contractionDims.n[0];
+      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
+      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
+      if (nOperands.size() != 2)
+        return failure();
+      if (llvm::all_of(nOperands, [](auto pair) {
+            return cast<ShapedType>(std::get<0>(pair).getType())
+                       .getShape()[std::get<1>(pair)] == 1;
+          })) {
+        operandUnitDindices = SmallVector<int64_t>{
+            -1, std::get<1>(nOperands[0]), std::get<1>(nOperands[1])};
+        return success();
+      }
     }
-    return SmallVector<Value, 3>{lhs, rhs, init};
-  }
-
-  Value expandResult(PatternRewriter &rewriter, Value result,
-                     RankedTensorType expandedType) const override {
-    if constexpr (reduceLeft)
-      return rewriter.create<tensor::ExpandShapeOp>(
-          result.getLoc(), expandedType, result,
-          getReassociationForReshapeAtDim(expandedType.getRank(), 0));
-    else
-      return rewriter.create<tensor::ExpandShapeOp>(
-          result.getLoc(), expandedType, result,
-          getReassociationForReshapeAtDim(expandedType.getRank(),
-                                          expandedType.getRank() - 1,
-                                          /*fromRight=*/true));
+    return failure();
   }
 };
 
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 0548f3f860a89..70568be99474e 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -1,23 +1,19 @@
 //RUN: mlir-opt -test-linalg-rank-reduce-contraction-ops --canonicalize -split-input-file %s | FileCheck %s
 
-func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512x256xf32>, %arg2: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> {
   // CHECK-LABEL: @singleton_batch_matmul_tensor
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
-  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
-  //  CHECK-DAG:    %[[C2:.*]] = arith.constant 2
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512x256xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128x256xf32>
   //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
   //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
   //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
-  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
-  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
-  //  CHECK-NEXT:   %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
-  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512x256xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128x256xf32>)
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, 128, 256]
   //  CHECK-NEXT:   return %[[RES]]
-  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
-      outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
-  return %1 : tensor<1x?x?xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512x256xf32>)
+      outs(%arg2 : tensor<1x128x256xf32>) -> tensor<1x128x256xf32>
+  return %1 : tensor<1x128x256xf32>
 }
 
 // -----
@@ -39,22 +35,20 @@ func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memr
 
 // -----
 
-func.func @singleton_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> {
   // CHECK-LABEL: @singleton_batch_matvec
-  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
-  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128xf32>
   //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
   //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
-  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
-  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
-  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128]
   //  CHECK-NEXT:   return %[[RES]]
-  %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
-      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
-  return %1 : tensor<1x?xf32>
+  %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512xf32>)
+      outs(%arg2 : tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
 }
 
 // -----
@@ -197,19 +191,22 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %a
 
 // -----
 
-func.func @matmul_transpose_a_to_vecmat(%arg0: memref<?x1xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x1xf32>) {
+func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> {
   // CHECK-LABEL: @matmul_transpose_a_to_vecmat
+  // CHECK: collapse_shape {{.*}} into tensor<256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<512xf32>
   // CHECK: linalg.vecmat
-    linalg.matmul_transpose_a ins(%arg0, %arg1: memref<?x1xf32>, memref<?x?xf32>) outs(%arg2: memref<?x1xf32>)
-    return
+  // CHECK: expand_shape {{.*}} into tensor<1x512xf32>
+    %0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<256x1xf32>, tensor<256x512xf32>) outs(%arg2: tensor<1x512xf32>) -> tensor<1x512xf32>
+    return %0 : tensor<1x512xf32>
 }
 
 // -----
 
-func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) {
+func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
   // CHECK-LABEL: @matmul_transpose_b_to_matvec
   // CHECK: linalg.matvec
-    linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<1x?xf32>)
+    linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<?x1xf32>)
     return
 }
 

>From acca39bc577ec6077ef71740b4176ccf65826a05 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 21 Jun 2024 11:14:03 -0500
Subject: [PATCH 12/13] address comments and extra cleanup

---
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +-
 .../Linalg/Transforms/DropUnitDims.cpp        | 126 ++++++++++--------
 .../Linalg/rank-reduce-contraction-ops.mlir   |   9 ++
 .../TestLinalgRankReduceContractionOps.cpp    |   3 +-
 4 files changed, 84 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c49383c600a57..3682a68b0e2c8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,8 +1692,8 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
-/// Adds patterns that that reduce the rank of named contraction ops that have
-/// unit dimensions in the operand(s) by converting to a senquence of `collapse_shape`,
+/// Adds patterns that reduce the rank of named contraction ops that have
+/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
 /// `<corresponding linalg named op>`, `expand_shape` (if on tensors).  For example a
 /// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
 /// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 771b40d6a2001..e1daeb3ad666e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -838,10 +838,12 @@ struct LinalgFoldUnitExtentDimsPass
 
 namespace {
 
+/// Returns reassociation indices for collapsing/expanding a
+/// tensor of rank `rank` at position `pos`.
 static SmallVector<ReassociationIndices>
 getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
   SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
-  auto lastDim = pos == rank - 1;
+  bool lastDim = pos == rank - 1;
   if (rank > 2) {
     for (int64_t i = 0; i < rank - 1; i++) {
       if (i == pos || (lastDim && i == pos - 1))
@@ -855,6 +857,8 @@ getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
   return reassociation;
 }
 
+/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
+/// If `pos < 0`, then don't collapse.
 static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                     int64_t pos) {
   if (pos < 0)
@@ -868,22 +872,30 @@ static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
       ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
 }
 
+/// Base class for all rank reduction patterns for contraction ops
+/// with unit dimensions.  All patterns should convert one named op
+/// to another named op.  Intended to reduce only one iteration space dim
+/// at a time.
+/// Reducing multiple dims will happen with recursive application of
+/// pattern rewrites.
 template <typename FromOpTy, typename ToOpTy>
 struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
   using OpRewritePattern<FromOpTy>::OpRewritePattern;
 
-  SmallVector<Value, 3>
+  /// Collapse all collapsible operands.
+  SmallVector<Value>
   collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                    ArrayRef<int64_t> operandCollapseDims) const {
     assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
            "expected 3 operands and dims");
-    return llvm::to_vector(llvm::map_range(
+    return llvm::map_to_vector(
         llvm::zip(operands, operandCollapseDims), [&](auto pair) {
           return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                         std::get<1>(pair));
-        }));
+        });
   }
 
+  /// Expand result tensor.
   Value expandResult(PatternRewriter &rewriter, Value result,
                      RankedTensorType expandedType, int64_t dim) const {
     return rewriter.create<tensor::ExpandShapeOp>(
@@ -905,12 +917,6 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
     auto init = inits[0];
     SmallVector<Value> operands{lhs, rhs, init};
 
-    auto maybeContractionDims = inferContractionDims(contractionOp);
-    if (failed(maybeContractionDims))
-      return rewriter.notifyMatchFailure(contractionOp,
-                                         "could not infer contraction dims");
-
-    auto contractionDims = maybeContractionDims.value();
     SmallVector<int64_t> operandUnitDims;
     if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
       return rewriter.notifyMatchFailure(contractionOp,
@@ -935,80 +941,89 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
 
     auto results = contractionOp.getResults();
     assert(results.size() < 2 && "expected at most one result");
-    if (results.size() < 1)
+    if (results.empty()) {
       rewriter.replaceOp(contractionOp, collapsedOp);
-    else
+    } else {
       rewriter.replaceOp(
           contractionOp,
           expandResult(rewriter, collapsedOp.getResultTensors()[0],
                        cast<RankedTensorType>(results[0].getType()),
                        operandUnitDims[2]));
+    }
 
     return success();
   }
 
+  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
+  /// for each operand that should be collapsed in this pattern.  If an
+  /// operand shouldn't be collapsed, the index should be negative.
   virtual LogicalResult
   getOperandUnitDims(LinalgOp op,
-                     SmallVectorImpl<int64_t> &operandUnitDindices) const = 0;
+                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
 };
 
+/// Patterns for unbatching batched contraction ops
 template <typename FromOpTy, typename ToOpTy>
-struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
+struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
 
-  LogicalResult getOperandUnitDims(
-      LinalgOp op,
-      SmallVectorImpl<int64_t> &operandUnitDindices) const override {
-    auto inputs = op.getDpsInputs();
-    auto inits = op.getDpsInits();
-    if (inputs.size() != 2 || inits.size() != 1)
-      return failure();
-
+  /// Look for unit batch dims to collapse.
+  LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
     auto maybeContractionDims = inferContractionDims(op);
-    if (failed(maybeContractionDims))
+    if (failed(maybeContractionDims)) {
+      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
       return failure();
+    }
     auto contractionDims = maybeContractionDims.value();
 
     if (contractionDims.batch.size() != 1)
       return failure();
     auto batchDim = contractionDims.batch[0];
-    SmallVector<std::pair<Value, unsigned>, 2> bOperands;
+    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
     op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
     if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
           return cast<ShapedType>(std::get<0>(pair).getType())
                      .getShape()[std::get<1>(pair)] != 1;
-        }))
+        })) {
+      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
       return failure();
+    }
 
-    operandUnitDindices = SmallVector<int64_t>{std::get<1>(bOperands[0]),
-                                               std::get<1>(bOperands[1]),
-                                               std::get<1>(bOperands[2])};
+    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
+                                           std::get<1>(bOperands[1]),
+                                           std::get<1>(bOperands[2])};
     return success();
   }
 };
 
+/// Patterns for reducing non-batch dimensions
 template <typename FromOpTy, typename ToOpTy>
 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
 
+  /// Helper for determining whether the lhs/init or rhs/init are reduced.
   static bool constexpr reduceLeft =
-      (std::is_same<FromOpTy, BatchMatmulOp>::value &&
-       std::is_same<ToOpTy, BatchVecmatOp>::value) ||
-      (std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value &&
-       std::is_same<ToOpTy, BatchVecmatOp>::value) ||
-      (std::is_same<FromOpTy, MatmulOp>::value &&
-       std::is_same<ToOpTy, VecmatOp>::value) ||
-      (std::is_same<FromOpTy, MatmulTransposeAOp>::value &&
-       std::is_same<ToOpTy, VecmatOp>::value) ||
-      (std::is_same<FromOpTy, MatvecOp>::value &&
-       std::is_same<ToOpTy, DotOp>::value);
-
-  LogicalResult getOperandUnitDims(
-      LinalgOp op,
-      SmallVectorImpl<int64_t> &operandUnitDindices) const override {
+      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
+       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
+       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+      (std::is_same_v<FromOpTy, MatmulOp> &&
+       std::is_same_v<ToOpTy, VecmatOp>) ||
+      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
+       std::is_same_v<ToOpTy, VecmatOp>) ||
+      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
+
+  /// Look for non-batch spatial dims to collapse.
+  LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
     auto maybeContractionDims = inferContractionDims(op);
-    if (failed(maybeContractionDims))
+    if (failed(maybeContractionDims)) {
+      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
       return failure();
+    }
     auto contractionDims = maybeContractionDims.value();
 
     if constexpr (reduceLeft) {
@@ -1021,8 +1036,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
             return cast<ShapedType>(std::get<0>(pair).getType())
                        .getShape()[std::get<1>(pair)] == 1;
           })) {
-        operandUnitDindices = SmallVector<int64_t>{
-            std::get<1>(mOperands[0]), -1, std::get<1>(mOperands[1])};
+        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
+                                               std::get<1>(mOperands[1])};
         return success();
       }
     } else {
@@ -1035,11 +1050,12 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
             return cast<ShapedType>(std::get<0>(pair).getType())
                        .getShape()[std::get<1>(pair)] == 1;
           })) {
-        operandUnitDindices = SmallVector<int64_t>{
-            -1, std::get<1>(nOperands[0]), std::get<1>(nOperands[1])};
+        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
+                                               std::get<1>(nOperands[1])};
         return success();
       }
     }
+    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
     return failure();
   }
 };
@@ -1050,13 +1066,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
   // Unbatching patterns for unit batch size
-  patterns.add<RankReduceBatched<BatchMatmulOp, MatmulOp>>(context);
-  patterns.add<RankReduceBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
-      context);
-  patterns.add<RankReduceBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
-      context);
-  patterns.add<RankReduceBatched<BatchMatvecOp, MatvecOp>>(context);
-  patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
+  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
+  patterns
+      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+          context);
+  patterns
+      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+          context);
+  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
+  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
 
   // Non-batch rank 1 reducing patterns
   patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 70568be99474e..fd59a4a52e378 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -212,6 +212,15 @@ func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x
 
 // -----
 
+func.func @batchmatmul_transpose_b_to_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
+  // CHECK-LABEL: @batchmatmul_transpose_b_to_to_dot
+  // CHECK: linalg.dot
+    %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<1x1x?xf32>, tensor<1x1x?xf32>) outs(%arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> 
+    return %0 : tensor<1x1x1xf32>
+}
+
+// -----
+
 func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
   // CHECK-LABEL: @nonsingleton_batch_matmul
   // CHECK-NOT:   collapse_shape
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
index 5ca27be30a687..8b455d7d68c30 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
@@ -1,5 +1,4 @@
-//===- TestLinalgRankReduceContractionOps.cpp - Test Linalg rank reduce
-//contractions ---===//
+//===- TestLinalgRankReduceContractionOps.cpp -----------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

>From 01197ef9f661bc40348c48fd907aab7f1dc118b3 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 21 Jun 2024 11:29:46 -0500
Subject: [PATCH 13/13] add more tests

---
 .../Linalg/rank-reduce-contraction-ops.mlir   | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index fd59a4a52e378..c086d0fd7e633 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -203,6 +203,18 @@ func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<
 
 // -----
 
+func.func @batch_matmul_transpose_a_to_batch_vecmat(%arg0: tensor<64x256x1xf32>, %arg1: tensor<64x256x512xf32>, %arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32> {
+  // CHECK-LABEL: @batch_matmul_transpose_a_to_batch_vecmat
+  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<64x512xf32>
+  // CHECK: linalg.batch_vecmat
+  // CHECK: expand_shape {{.*}} into tensor<64x1x512xf32>
+    %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<64x256x1xf32>, tensor<64x256x512xf32>) outs(%arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32>
+    return %0 : tensor<64x1x512xf32>
+}
+
+// -----
+
 func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
   // CHECK-LABEL: @matmul_transpose_b_to_matvec
   // CHECK: linalg.matvec
@@ -212,6 +224,17 @@ func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x
 
 // -----
 
+func.func @batchmatmul_transpose_b_to_batchmatvec_tensor(%arg0: tensor<64x128x256xf32>, %arg1: tensor<64x1x256xf32>, %arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> {
+  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<64x128xf32>
+  // CHECK: linalg.batch_matvec
+  // CHECK: expand_shape {{.*}} into tensor<64x128x1xf32>
+    %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<64x128x256xf32>, tensor<64x1x256xf32>) outs(%arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> 
+    return %0 : tensor<64x128x1xf32>
+}
+
+// -----
+
 func.func @batchmatmul_transpose_b_to_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
   // CHECK-LABEL: @batchmatmul_transpose_b_to_to_dot
   // CHECK: linalg.dot



More information about the Mlir-commits mailing list