[Mlir-commits] [mlir] implement canonicalizer for batched linalg operations (PR #95710)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jun 16 10:13:13 PDT 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/95710

>From 0418e51cf33bc59cc6f19ed00edc8c2d62e4d9df Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 15 Jun 2024 10:46:44 -0500
Subject: [PATCH 1/3] implement canonicalizer for batched linalg operations

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  49 ++-----
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 121 ++++++++++++++++++
 .../linalg/opdsl/ops/core_named_ops.py        |   5 +
 3 files changed, 138 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index fad234a9dcae9..3cbfb58ed8506 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,41 +304,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: reciprocal
-  cpp_class_name: ReciprocalOp
-  doc: |-
-    Applies reciprocal(x) elementwise.
-
-    No numeric casting is performed on the input operand.
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: I
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<() -> ()>
-  - !LinalgOperandDefConfig
-    name: O
-    kind: output_tensor
-    type_var: T1
-    shape_map: affine_map<() -> ()>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<() -> ()>
-    - affine_map<() -> ()>
-  iterator_types: []
-  assignments:
-  - !ScalarAssign
-    arg: O
-    value: !ScalarExpression
-      scalar_fn:
-        kind: unary
-        fn_name: reciprocal
-        operands:
-        - !ScalarExpression
-          scalar_arg: I
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: round
   cpp_class_name: RoundOp
@@ -516,7 +481,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: erf
-  cpp_class_name: erfOp
+  cpp_class_name: ErfOp
   doc: |-
     Applies erf(x) elementwise.
 
@@ -959,7 +924,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: powf
-  cpp_class_name: PowFOp
+  cpp_class_name: PowfOp
   doc: |-
     Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
 
@@ -1622,6 +1587,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -1692,6 +1659,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -1762,6 +1731,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -2140,6 +2111,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -2208,6 +2181,8 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..ecd669165efc7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -42,6 +43,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <numeric>
 #include <optional>
 
 using namespace mlir;
@@ -578,6 +580,125 @@ class RegionBuilderHelper {
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+template <typename BatchOpTy, typename OpTy>
+struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
+  using OpRewritePattern<BatchOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = batchMatmulOp.getLoc();
+    auto inputs = batchMatmulOp.getDpsInputs();
+    auto inits = batchMatmulOp.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "expected 2 inputs and 1 init");
+    auto lhs = inputs[0];
+    auto rhs = inputs[1];
+    auto init = inits[0];
+
+    auto lhsType = cast<ShapedType>(lhs.getType());
+    auto rhsType = cast<ShapedType>(rhs.getType());
+    auto initType = cast<ShapedType>(init.getType());
+    if (ShapedType::isDynamic(lhsType.getShape()[0]) ||
+        lhsType.getShape()[0] != rhsType.getShape()[0] ||
+        rhsType.getShape()[0] != initType.getShape()[0])
+      return rewriter.notifyMatchFailure(
+          batchMatmulOp, "expected batch sizes of all operands to be same");
+
+    auto results = batchMatmulOp.getResults();
+    if (results.size() > 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "expected at most one result");
+
+    SmallVector<Type, 1> resultType;
+    if (results.size() == 1) {
+      auto oldResultType = cast<RankedTensorType>(results[0].getType());
+      resultType.push_back(
+          RankedTensorType::get(oldResultType.getShape().drop_front(1),
+                                oldResultType.getElementType()));
+    }
+
+    auto collapseSingletonDim = [&](Value val) -> Value {
+      SmallVector<ReassociationIndices> reassociation({{0, 1}});
+      auto valType = cast<ShapedType>(val.getType());
+      for (auto i = 2; i < valType.getRank(); i++)
+        reassociation.push_back({i});
+      if (isa<RankedTensorType>(valType)) {
+        RankedTensorType collapsedType = RankedTensorType::get(
+            valType.getShape().drop_front(1), valType.getElementType());
+        return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
+                                                        reassociation);
+      }
+      MemRefType collapsedType = MemRefType::get(
+          valType.getShape().drop_front(1), valType.getElementType());
+      return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
+                                                      reassociation);
+    };
+
+    auto collapsedLhs = collapseSingletonDim(lhs);
+    auto collapsedRhs = collapseSingletonDim(rhs);
+    auto collapsedInit = collapseSingletonDim(init);
+
+    auto collapsedOp = rewriter.create<OpTy>(
+        loc, resultType, ValueRange{collapsedLhs, collapsedRhs},
+        ValueRange{collapsedInit});
+    for (auto attr : batchMatmulOp->getAttrs()) {
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+        continue;
+      collapsedOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    if (results.size() < 1) {
+      rewriter.replaceOp(batchMatmulOp, collapsedOp);
+    } else {
+      SmallVector<ReassociationIndices> reassociation({{0, 1}});
+      auto resultType = cast<ShapedType>(results[0].getType());
+      for (auto i = 2; i < resultType.getRank(); i++)
+        reassociation.push_back({i});
+      Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
+          loc, resultType, collapsedOp.getResultTensors()[0], reassociation);
+      rewriter.replaceOp(batchMatmulOp, expandedResult);
+    }
+
+    return success();
+  }
+};
+
+} // namespace
+
+void BatchMatmulOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
+}
+
+void BatchMatmulTransposeAOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+      context);
+}
+
+void BatchMatmulTransposeBOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+      context);
+}
+
+void BatchMatvecOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
+}
+
+void BatchVecmatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 43410aaa6af1b..b4b36ba0bfe51 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -518,6 +518,7 @@ def batch_matmul(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -537,6 +538,7 @@ def batch_matmul_transpose_a(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
@@ -556,6 +558,7 @@ def batch_matmul_transpose_b(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -642,6 +645,7 @@ def batch_matvec(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.m, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -660,6 +664,7 @@ def batch_vecmat(
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
+    defines(Canonicalizer)
     domain(D.b, D.n, D.k)
     implements(ContractionOpInterface)
     C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(

>From 02b2ca083d145fc88a9498480e4a831affdebf10 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 16 Jun 2024 12:01:33 -0500
Subject: [PATCH 2/3] add tests

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  34 +++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  12 +-
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 137 +++++++++++++++++-
 3 files changed, 174 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3cbfb58ed8506..41f90483c93b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,40 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: reciprocal
+  cpp_class_name: ReciprocalOp
+  doc: |-
+    Applies reciprocal(x) elementwise.
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: reciprocal
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: round
   cpp_class_name: RoundOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ecd669165efc7..4e47b6018c445 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -605,16 +605,12 @@ struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
     auto lhsType = cast<ShapedType>(lhs.getType());
     auto rhsType = cast<ShapedType>(rhs.getType());
     auto initType = cast<ShapedType>(init.getType());
-    if (ShapedType::isDynamic(lhsType.getShape()[0]) ||
-        lhsType.getShape()[0] != rhsType.getShape()[0] ||
-        rhsType.getShape()[0] != initType.getShape()[0])
-      return rewriter.notifyMatchFailure(
-          batchMatmulOp, "expected batch sizes of all operands to be same");
+    if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
+        initType.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1");
 
     auto results = batchMatmulOp.getResults();
-    if (results.size() > 1)
-      return rewriter.notifyMatchFailure(batchMatmulOp,
-                                         "expected at most one result");
+    assert(results.size() < 2 && "expected at most one result");
 
     SmallVector<Type, 1> resultType;
     if (results.size() == 1) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 928030a81dc02..8514bcb089891 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
   return %0 : tensor<2x3xf32>
 }
 
-// ----
+// -----
 
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,3 +1096,138 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }
 
+// -----
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+  // CHECK-LABEL: @singleton_batch_matmul_tensor
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-DAG:    %[[C2:.*]] = arith.constant 2
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+      outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+  return %1 : tensor<1x?x?xf32>
+}
+
+// -----
+
+func.func @singletone_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @singletone_batch_matmul_memref
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+      outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @singletone_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singletone_batch_matvec
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
+      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singletone_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singletone_batch_vecmat
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  //  CHECK-NEXT:   return %[[RES]]
+  %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singletone_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singletone_batchmatmul_transpose_a
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @singletone_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singletone_batchmatmul_transpose_b
+  //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+  //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+  //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  //  CHECK-NEXT:    linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  //  CHECK-NEXT:   return
+  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+      outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+  return %1 : tensor<2x?x?xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+

>From 543b0d643506d12c658e2984d943873ed4c8b78b Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 16 Jun 2024 12:12:35 -0500
Subject: [PATCH 3/3] remove unnecessary changes

---
 .../mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml        | 1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp                        | 2 --
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 41f90483c93b3..3f0aa33767a75 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -309,6 +309,7 @@ metadata: !LinalgOpMetadata
   cpp_class_name: ReciprocalOp
   doc: |-
     Applies reciprocal(x) elementwise.
+
     No numeric casting is performed on the input operand.
 structured_op: !LinalgStructuredOpConfig
   args:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4e47b6018c445..8df33a107c2cb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -43,7 +42,6 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
-#include <numeric>
 #include <optional>
 
 using namespace mlir;



More information about the Mlir-commits mailing list