[Mlir-commits] [mlir] 431213c - [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (#95710)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 24 11:06:35 PDT 2024


Author: srcarroll
Date: 2024-06-24T13:06:31-05:00
New Revision: 431213c99d7707114d8e7956073a057cf1607160

URL: https://github.com/llvm/llvm-project/commit/431213c99d7707114d8e7956073a057cf1607160
DIFF: https://github.com/llvm/llvm-project/commit/431213c99d7707114d8e7956073a057cf1607160.diff

LOG: [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (#95710)

This patch introduces pattern rewrites for reducing the rank of named
linalg contraction ops with unit spatial dim(s) to other named
contraction ops. For example `linalg.batch_matmul` with batch size 1 ->
`linalg.matmul` and `linalg.matmul` with unit LHS spatial dim ->
`linalg.vecmat`, etc. These patterns don't support reducing the rank
along reduction dimension as those don't convert to other named
contraction ops.

Added: 
    mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8424207ea47e5..b0871a5dff5da 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1713,6 +1713,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Adds patterns that reduce the rank of named contraction ops that have
+/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
+/// `<corresponding linalg named op>`, `expand_shape` (if on tensors).  For example a
+/// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
+/// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`.
+void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c0829397f1f85..36f8696bf1b27 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -833,4 +833,265 @@ struct LinalgFoldUnitExtentDimsPass
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
+
+} // namespace
+
+namespace {
+
+/// Returns reassociation indices for collapsing/expanding a
+/// tensor of rank `rank` at position `pos`.
+static SmallVector<ReassociationIndices>
+getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
+  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
+  bool lastDim = pos == rank - 1;
+  if (rank > 2) {
+    for (int64_t i = 0; i < rank - 1; i++) {
+      if (i == pos || (lastDim && i == pos - 1))
+        reassociation[i] = ReassociationIndices{i, i + 1};
+      else if (i < pos)
+        reassociation[i] = ReassociationIndices{i};
+      else
+        reassociation[i] = ReassociationIndices{i + 1};
+    }
+  }
+  return reassociation;
+}
+
+/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
+/// If `pos < 0`, then don't collapse.
+static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
+                                    int64_t pos) {
+  if (pos < 0)
+    return val;
+  auto valType = cast<ShapedType>(val.getType());
+  SmallVector<int64_t> collapsedShape(valType.getShape());
+  collapsedShape.erase(collapsedShape.begin() + pos);
+  return collapseValue(
+      rewriter, val.getLoc(), val, collapsedShape,
+      getReassociationForReshapeAtDim(valType.getRank(), pos),
+      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
+/// Base class for all rank reduction patterns for contraction ops
+/// with unit dimensions.  All patterns should convert one named op
+/// to another named op.  Intended to reduce only one iteration space dim
+/// at a time.
+/// Reducing multiple dims will happen with recusive application of
+/// pattern rewrites.
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
+  using OpRewritePattern<FromOpTy>::OpRewritePattern;
+
+  /// Collapse all collapsable operands.
+  SmallVector<Value>
+  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
+                   ArrayRef<int64_t> operandCollapseDims) const {
+    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
+           "expected 3 operands and dims");
+    return llvm::map_to_vector(
+        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
+          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
+                                        std::get<1>(pair));
+        });
+  }
+
+  /// Expand result tensor.
+  Value expandResult(PatternRewriter &rewriter, Value result,
+                     RankedTensorType expandedType, int64_t dim) const {
+    return rewriter.create<tensor::ExpandShapeOp>(
+        result.getLoc(), expandedType, result,
+        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
+  }
+
+  LogicalResult matchAndRewrite(FromOpTy contractionOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = contractionOp.getLoc();
+    auto inputs = contractionOp.getDpsInputs();
+    auto inits = contractionOp.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return rewriter.notifyMatchFailure(contractionOp,
+                                         "expected 2 inputs and 1 init");
+    auto lhs = inputs[0];
+    auto rhs = inputs[1];
+    auto init = inits[0];
+    SmallVector<Value> operands{lhs, rhs, init};
+
+    SmallVector<int64_t> operandUnitDims;
+    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
+      return rewriter.notifyMatchFailure(contractionOp,
+                                         "no reducable dims found");
+
+    SmallVector<Value> collapsedOperands =
+        collapseOperands(rewriter, operands, operandUnitDims);
+    Value collapsedLhs = collapsedOperands[0];
+    Value collapsedRhs = collapsedOperands[1];
+    Value collapsedInit = collapsedOperands[2];
+    SmallVector<Type, 1> collapsedResultTy;
+    if (isa<RankedTensorType>(collapsedInit.getType()))
+      collapsedResultTy.push_back(collapsedInit.getType());
+    auto collapsedOp = rewriter.create<ToOpTy>(
+        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
+        ValueRange{collapsedInit});
+    for (auto attr : contractionOp->getAttrs()) {
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+        continue;
+      collapsedOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    auto results = contractionOp.getResults();
+    assert(results.size() < 2 && "expected at most one result");
+    if (results.empty()) {
+      rewriter.replaceOp(contractionOp, collapsedOp);
+    } else {
+      rewriter.replaceOp(
+          contractionOp,
+          expandResult(rewriter, collapsedOp.getResultTensors()[0],
+                       cast<RankedTensorType>(results[0].getType()),
+                       operandUnitDims[2]));
+    }
+
+    return success();
+  }
+
+  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
+  /// for each operand that should be collapsed in this pattern.  If an
+  /// operand shouldn't be collapsed, the index should be negative.
+  virtual LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
+};
+
+/// Patterns for unbatching batched contraction ops
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  /// Look for unit batch dims to collapse.
+  LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
+    FailureOr<ContractionDimensions> maybeContractionDims =
+        inferContractionDims(op);
+    if (failed(maybeContractionDims)) {
+      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
+      return failure();
+    }
+    ContractionDimensions contractionDims = maybeContractionDims.value();
+
+    if (contractionDims.batch.size() != 1)
+      return failure();
+    auto batchDim = contractionDims.batch[0];
+    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
+    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
+    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
+          return cast<ShapedType>(std::get<0>(pair).getType())
+                     .getShape()[std::get<1>(pair)] != 1;
+        })) {
+      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
+      return failure();
+    }
+
+    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
+                                           std::get<1>(bOperands[1]),
+                                           std::get<1>(bOperands[2])};
+    return success();
+  }
+};
+
+/// Patterns for reducing non-batch dimensions
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  /// Helper for determining whether the lhs/init or rhs/init are reduced.
+  static bool constexpr reduceLeft =
+      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
+       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
+       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+      (std::is_same_v<FromOpTy, MatmulOp> &&
+       std::is_same_v<ToOpTy, VecmatOp>) ||
+      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
+       std::is_same_v<ToOpTy, VecmatOp>) ||
+      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
+
+  /// Look for non-batch spatial dims to collapse.
+  LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
+    FailureOr<ContractionDimensions> maybeContractionDims =
+        inferContractionDims(op);
+    if (failed(maybeContractionDims)) {
+      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
+      return failure();
+    }
+    ContractionDimensions contractionDims = maybeContractionDims.value();
+
+    if constexpr (reduceLeft) {
+      auto m = contractionDims.m[0];
+      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
+      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
+      if (mOperands.size() != 2)
+        return failure();
+      if (llvm::all_of(mOperands, [](auto pair) {
+            return cast<ShapedType>(std::get<0>(pair).getType())
+                       .getShape()[std::get<1>(pair)] == 1;
+          })) {
+        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
+                                               std::get<1>(mOperands[1])};
+        return success();
+      }
+    } else {
+      auto n = contractionDims.n[0];
+      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
+      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
+      if (nOperands.size() != 2)
+        return failure();
+      if (llvm::all_of(nOperands, [](auto pair) {
+            return cast<ShapedType>(std::get<0>(pair).getType())
+                       .getShape()[std::get<1>(pair)] == 1;
+          })) {
+        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
+                                               std::get<1>(nOperands[1])};
+        return success();
+      }
+    }
+    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
+    return failure();
+  }
+};
+
 } // namespace
+
+void mlir::linalg::populateContractionOpRankReducingPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  // Unbatching patterns for unit batch size
+  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
+  patterns
+      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+          context);
+  patterns
+      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+          context);
+  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
+  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
+
+  // Non-batch rank 1 reducing patterns
+  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
+  // Batch rank 1 reducing patterns
+  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
+  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
+  patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
+      context);
+  patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
+      context);
+
+  // Non-batch rank 0 reducing patterns
+  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
+  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
+}

diff  --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
new file mode 100644
index 0000000000000..c086d0fd7e633
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -0,0 +1,267 @@
+//RUN: mlir-opt -test-linalg-rank-reduce-contraction-ops --canonicalize -split-input-file %s | FileCheck %s
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512x256xf32>, %arg2: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> {
+  // CHECK-LABEL: @singleton_batch_matmul_tensor
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512x256xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128x256xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512x256xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128x256xf32>)
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, 128, 256]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512x256xf32>)
+      outs(%arg2 : tensor<1x128x256xf32>) -> tensor<1x128x256xf32>
+  return %1 : tensor<1x128x256xf32>
+}
+
+// -----
+
+func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @singleton_batch_matmul_memref
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+      outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> {
+  // CHECK-LABEL: @singleton_batch_matvec
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512xf32>)
+      outs(%arg2 : tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
+}
+
+// -----
+
+func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_vecmat
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_a
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_b
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK-LABEL: @matmul_to_matvec_tensor
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+  //  CHECK-DAG:    %[[C0:.*]] = arith.constant 0
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
+  //  CHECK-NEXT:   return %[[RES]]
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
+    return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @matmul_to_matvec
+  // CHECK: linalg.matvec
+    linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
+    return
+}
+
+// -----
+
+func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @matmul_to_vecmat
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[RESULT:.*]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   return %[[RES]]
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
+    return %0 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?xf32>, %arg2: memref<1x1x?xf32>) {
+  // CHECK-LABEL: @batch_matmul_to_vecmat
+  // CHECK: linalg.vecmat
+    linalg.batch_matmul ins(%arg0, %arg1: memref<1x1x?xf32>, memref<1x?x?xf32>) outs(%arg2: memref<1x1x?xf32>)
+    return
+}
+
+// -----
+
+func.func @matvec_to_dot(%arg0: memref<1x?xf32>, %arg1: memref<?xf32>, %arg2: memref<1xf32>) {
+  // CHECK-LABEL: @matvec_to_dot
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] []
+  //  CHECK-NEXT:   linalg.dot ins(%[[COLLAPSED_LHS]], %[[RHS]] : memref<?xf32>, memref<?xf32>) outs(%[[COLLAPSED_INIT]] : memref<f32>)
+    linalg.matvec ins(%arg0, %arg1: memref<1x?xf32>, memref<?xf32>) outs(%arg2: memref<1xf32>)
+    return
+}
+
+// -----
+
+func.func @vecmat_to_dot(%arg0: memref<?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<1xf32>) {
+  // CHECK-LABEL: @vecmat_to_dot
+  // CHECK: linalg.dot
+    linalg.vecmat ins(%arg0, %arg1: memref<?xf32>, memref<?x1xf32>) outs(%arg2: memref<1xf32>)
+    return
+}
+
+// -----
+
+func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-LABEL: @matvec_to_dot_tensor
+  // CHECK: linalg.dot
+    %0 = linalg.matvec ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?xf32>) outs(%arg2: tensor<1xf32>) -> tensor<1xf32>
+    return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> {
+  // CHECK-LABEL: @matmul_transpose_a_to_vecmat
+  // CHECK: collapse_shape {{.*}} into tensor<256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<512xf32>
+  // CHECK: linalg.vecmat
+  // CHECK: expand_shape {{.*}} into tensor<1x512xf32>
+    %0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<256x1xf32>, tensor<256x512xf32>) outs(%arg2: tensor<1x512xf32>) -> tensor<1x512xf32>
+    return %0 : tensor<1x512xf32>
+}
+
+// -----
+
+func.func @batch_matmul_transpose_a_to_batch_vecmat(%arg0: tensor<64x256x1xf32>, %arg1: tensor<64x256x512xf32>, %arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32> {
+  // CHECK-LABEL: @batch_matmul_transpose_a_to_batch_vecmat
+  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<64x512xf32>
+  // CHECK: linalg.batch_vecmat
+  // CHECK: expand_shape {{.*}} into tensor<64x1x512xf32>
+    %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<64x256x1xf32>, tensor<64x256x512xf32>) outs(%arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32>
+    return %0 : tensor<64x1x512xf32>
+}
+
+// -----
+
+func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @matmul_transpose_b_to_matvec
+  // CHECK: linalg.matvec
+    linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<?x1xf32>)
+    return
+}
+
+// -----
+
+func.func @batchmatmul_transpose_b_to_batchmatvec_tensor(%arg0: tensor<64x128x256xf32>, %arg1: tensor<64x1x256xf32>, %arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> {
+  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<64x128xf32>
+  // CHECK: linalg.batch_matvec
+  // CHECK: expand_shape {{.*}} into tensor<64x128x1xf32>
+    %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<64x128x256xf32>, tensor<64x1x256xf32>) outs(%arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> 
+    return %0 : tensor<64x128x1xf32>
+}
+
+// -----
+
+func.func @batchmatmul_transpose_b_to_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
+  // CHECK-LABEL: @batchmatmul_transpose_b_to_to_dot
+  // CHECK: linalg.dot
+    %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<1x1x?xf32>, tensor<1x1x?xf32>) outs(%arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> 
+    return %0 : tensor<1x1x1xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+      outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+  return %1 : tensor<2x?x?xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index b28f2b3564662..283e426b4e594 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgDropUnitDims.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
+  TestLinalgRankReduceContractionOps.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
 

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
new file mode 100644
index 0000000000000..8b455d7d68c30
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
@@ -0,0 +1,67 @@
+//===- TestLinalgRankReduceContractionOps.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing rank reduing patterns for named
+// contraction ops with unit dims.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLinalgRankReduceContractionOps
+    : public PassWrapper<TestLinalgRankReduceContractionOps,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestLinalgRankReduceContractionOps)
+
+  TestLinalgRankReduceContractionOps() = default;
+  TestLinalgRankReduceContractionOps(
+      const TestLinalgRankReduceContractionOps &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect, tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-rank-reduce-contraction-ops";
+  }
+  StringRef getDescription() const final {
+    return "Test Linalg rank reduce contraction ops with unit dims";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    func::FuncOp funcOp = this->getOperation();
+
+    RewritePatternSet patterns(context);
+    linalg::populateContractionOpRankReducingPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                            std::move(patterns))))
+      return signalPassFailure();
+    return;
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgRankReduceContractionOps() {
+  PassRegistration<TestLinalgRankReduceContractionOps>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e4fbd03d8c678..8cafb0afac9ae 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -108,6 +108,7 @@ void registerTestLinalgDecomposeOps();
 void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
+void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
@@ -239,6 +240,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgDropUnitDims();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgGreedyFusion();
+  mlir::test::registerTestLinalgRankReduceContractionOps();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();


        


More information about the Mlir-commits mailing list