[Mlir-commits] [mlir] [mlir] move LinalgToStandard to Linalg as ConvertToFunctionCalls (PR #121392)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Dec 31 06:15:10 PST 2024


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/121392

The remnants of the ConvertLinalgToStandard pass were still present in the codebase under this name, years after the Standard dialect was dismantled. In practice, this pass / pattern set was only rewriting Linalg operations into function calls. All this makes the existence of the pass under its current name highly confusing.

Move the logic under Linalg/Transforms, similarly to other "lowerings" from Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard to ConvertLinalgToFunctionCalls. Merge the two relevant test files; ironically, one of them was already called library-calls.mlir. Simplify the code a little.

>From 2448cdf9ff52e8451eceee5e412612fb71398b76 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Tue, 31 Dec 2024 15:10:56 +0100
Subject: [PATCH] [mlir] move LinalgToStandard to Linalg as
 ConvertToFunctionCalls

The remnants of the ConvertLinalgToStandard pass were still present in the
codebase under this name, years after the Standard dialect was dismantled.
Practically, this pass / pattern set was only performing the rewrite of Linalg
operations to function calls. All this makes the existence of the pass highly
confusing.

Move the logic under Linalg/Transforms, similarly to other "lowerings" from
Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard
to ConvertLinalgToFunctionCalls. Merge the two relevant test files; ironically,
one of them was called library-calls.mlir. Simplify the code a little.
---
 .../LinalgToStandard/LinalgToStandard.h       | 27 +-----
 mlir/include/mlir/Conversion/Passes.td        | 11 ---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  7 ++
 .../Dialect/Linalg/Transforms/Transforms.h    | 26 ++++++
 mlir/lib/Conversion/CMakeLists.txt            |  1 -
 .../LinalgToStandard/CMakeLists.txt           | 23 -----
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |  2 +
 .../Linalg/Transforms/FunctionCalls.cpp}      | 28 +++----
 ...library-calls.mlir => function-calls.mlir} | 84 ++++++++++++++++++-
 mlir/test/Dialect/Linalg/standard.mlir        | 81 ------------------
 10 files changed, 131 insertions(+), 159 deletions(-)
 delete mode 100644 mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
 rename mlir/lib/{Conversion/LinalgToStandard/LinalgToStandard.cpp => Dialect/Linalg/Transforms/FunctionCalls.cpp} (87%)
 rename mlir/test/Dialect/Linalg/{library-calls.mlir => function-calls.mlir} (61%)
 delete mode 100644 mlir/test/Dialect/Linalg/standard.mlir

diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index eefa2c4724833b..346cf62cdb8e86 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -22,34 +22,11 @@ class OperationPass;
 
 namespace linalg {
 
-//===----------------------------------------------------------------------===//
-// Patterns to convert a LinalgOp to func.call @external library implementation.
-//===----------------------------------------------------------------------===//
-// These patterns are exposed individually because they are expected to be
-// typically used individually.
-
-// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
-// function. The implementation of the function can be either in the same module
-// or in an externally linked library.
-// This is a generic entry point for all LinalgOp, except for CopyOp, for which
-// more specialized patterns are provided.
-class LinalgOpToLibraryCallRewrite
-    : public OpInterfaceRewritePattern<LinalgOp> {
-public:
-  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
-
-} // namespace linalg
-
 /// Create a pass to convert Linalg operations to the Standard dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
 
+} // namespace linalg
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 58ee87cf820396..7a3ffa97bd5212 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -713,17 +713,6 @@ def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// LinalgToStandard
-//===----------------------------------------------------------------------===//
-
-def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
-  let summary = "Convert the operations from the linalg dialect into the "
-                "Standard dialect";
-  let constructor = "mlir::createConvertLinalgToStandardPass()";
-  let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect"];
-}
-
 //===----------------------------------------------------------------------===//
 // MathToLibm
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0a..99c6d1c14674a1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -58,6 +58,14 @@ def ConvertLinalgToParallelLoopsPass
   ];
 }
 
+def ConvertLinalgToFunctionCallsPass
+    : Pass<"convert-linalg-to-function-calls", "ModuleOp"> {
+  let summary = "Convert the operations from the Linalg dialect into "
+                "function calls";
+  let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect",
+                           "memref::MemRefDialect"];
+}
+
 def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let options = [
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..1ae27136512873 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1885,6 +1885,32 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 /// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
+//===----------------------------------------------------------------------===//
+// Patterns to convert a LinalgOp to func.call @external library implementation.
+//
+// These patterns are exposed individually because they are expected to be
+// typically used individually.
+//===----------------------------------------------------------------------===//
+
+/// Creates a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
+/// function. The implementation of the function can be either in the same
+/// module or in an externally linked library.
+/// This is a generic entry point for all LinalgOps, except for CopyOp, for
+/// which more specialized patterns are provided.
+class LinalgOpToLibraryCallRewrite
+    : public OpInterfaceRewritePattern<LinalgOp> {
+public:
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Populates the given list with patterns that convert from Linalg to library
+/// calls using the `func` dialect.
+void populateLinalgToFunctionCallsConversionPatterns(
+    RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 62461c0cea08af..1c7318bb584d45 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -31,7 +31,6 @@ add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
 add_subdirectory(IndexToLLVM)
 add_subdirectory(IndexToSPIRV)
-add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
diff --git a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt b/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
deleted file mode 100644
index 7fc4af54031855..00000000000000
--- a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
+++ /dev/null
@@ -1,23 +0,0 @@
-add_mlir_conversion_library(MLIRLinalgToStandard
-  LinalgToStandard.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToStandard
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRFuncDialect
-  MLIRIR
-  MLIRLinalgDialect
-  MLIRLinalgTransforms
-  MLIRLLVMDialect
-  MLIRMemRefDialect
-  MLIRPass
-  MLIRSCFDialect
-  MLIRTransforms
-  )
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..d6bdf1d52dd1da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
   FoldAddIntoDest.cpp
+  FunctionCalls.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
@@ -68,6 +69,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRMeshTransforms
   MLIRLinalgDialect
   MLIRLinalgUtils
+  MLIRLLVMDialect
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRPass
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
similarity index 87%
rename from mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
index 4d1f35c767304d..a202dac0aa2326 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
@@ -1,4 +1,4 @@
-//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
+//===- FunctionCalls.cpp - Linalg to function calls conversion ------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,20 +6,19 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
-
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Pass/Pass.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
-#include "mlir/Conversion/Passes.h.inc"
+#define GEN_PASS_DEF_CONVERTLINALGTOFUNCTIONCALLSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
 } // namespace mlir
 
 using namespace mlir;
@@ -123,8 +122,7 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
   return success();
 }
 
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void mlir::linalg::populateLinalgToStandardConversionPatterns(
+void mlir::linalg::populateLinalgToFunctionCallsConversionPatterns(
     RewritePatternSet &patterns) {
   // TODO: ConvOp conversion needs to export a descriptor with relevant
   // attribute values such as kernel striding and dilation.
@@ -132,13 +130,14 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns(
 }
 
 namespace {
-struct ConvertLinalgToStandardPass
-    : public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
+struct ConvertLinalgToFunctionCallsPass
+    : public impl::ConvertLinalgToFunctionCallsPassBase<
+          ConvertLinalgToFunctionCallsPass> {
   void runOnOperation() override;
 };
 } // namespace
 
-void ConvertLinalgToStandardPass::runOnOperation() {
+void ConvertLinalgToFunctionCallsPass::runOnOperation() {
   auto module = getOperation();
   ConversionTarget target(getContext());
   target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
@@ -146,12 +145,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
                          scf::SCFDialect>();
   target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
   RewritePatternSet patterns(&getContext());
-  populateLinalgToStandardConversionPatterns(patterns);
+  populateLinalgToFunctionCallsConversionPatterns(patterns);
   if (failed(applyFullConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
-
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertLinalgToStandardPass() {
-  return std::make_unique<ConvertLinalgToStandardPass>();
-}
diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/function-calls.mlir
similarity index 61%
rename from mlir/test/Dialect/Linalg/library-calls.mlir
rename to mlir/test/Dialect/Linalg/function-calls.mlir
index 1fa675d8b4b68a..103fcb16c51732 100644
--- a/mlir/test/Dialect/Linalg/library-calls.mlir
+++ b/mlir/test/Dialect/Linalg/function-calls.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-linalg-to-std -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-function-calls -split-input-file --verify-diagnostics | FileCheck %s
 
 func.func private @printMemrefF32(memref<*xf32>)
 
@@ -99,3 +99,85 @@ func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8
                               ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>)
   return
 }
+
+// -----
+
+func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
+          %arg1: memref<?xf32, strided<[1], offset: ?>>,
+          %arg2: memref<f32>) {
+  linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
+                               memref<?xf32, strided<[1], offset: ?>>)
+             outs(%arg2: memref<f32>)
+  return
+}
+// CHECK-LABEL: func @dot(
+//  CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+//  CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+//  CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
+//       CHECK:   %[[o0:.*]] = memref.cast %[[arg0]] :
+//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   %[[o1:.*]] = memref.cast %[[arg1]] :
+//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   %[[o2:.*]] = memref.cast %[[arg2]] :
+//  CHECK-SAME:     memref<f32> to memref<f32, strided<[], offset: ?>>
+//       CHECK:   call @linalg_dot_viewsxf32_viewsxf32_viewf32(
+//  CHECK-SAME:     %[[o0]], %[[o1]], %[[o2]]) :
+//  CHECK-SAME:   memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
+
+// -----
+
+#matmul_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmul_trait = {
+  iterator_types = ["parallel", "parallel", "reduction"],
+  indexing_maps = #matmul_accesses,
+  library_call = "external_outerproduct_matmul"
+}
+
+!vector_type_A = vector<4xf32>
+!vector_type_B = vector<4xf32>
+!vector_type_C = vector<4x4xf32>
+
+!matrix_type_A = memref<?x?x!vector_type_A>
+!matrix_type_B = memref<?x?x!vector_type_B>
+!matrix_type_C = memref<?x?x!vector_type_C>
+
+func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
+  linalg.generic #matmul_trait
+      ins(%A, %B : !matrix_type_A, !matrix_type_B)
+     outs(%C : !matrix_type_C) {
+    ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
+      %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
+      linalg.yield %d: !vector_type_C
+  }
+  return
+}
+// CHECK-LABEL: func @matmul_vec_impl(
+// CHECK:  call @external_outerproduct_matmul(%{{.*}}) :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>)  {
+  // expected-error @below {{failed to legalize}}
+  %0 = linalg.generic {
+    indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
+  ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+  ^bb0(%in: f32, %out: f32): 
+    linalg.yield %in : f32
+  } -> tensor<?xf32>
+  return 
+}
+
+// -----
+
+func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error @below {{failed to legalize}}
+  %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
deleted file mode 100644
index f50016f9ea477f..00000000000000
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ /dev/null
@@ -1,81 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-to-std --split-input-file -verify-diagnostics | FileCheck %s
-
-func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
-          %arg1: memref<?xf32, strided<[1], offset: ?>>,
-          %arg2: memref<f32>) {
-  linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
-                               memref<?xf32, strided<[1], offset: ?>>)
-             outs(%arg2: memref<f32>)
-  return
-}
-// CHECK-LABEL: func @dot(
-//  CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-//  CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-//  CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
-//       CHECK:   %[[o0:.*]] = memref.cast %[[arg0]] :
-//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-//       CHECK:   %[[o1:.*]] = memref.cast %[[arg1]] :
-//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-//       CHECK:   %[[o2:.*]] = memref.cast %[[arg2]] :
-//  CHECK-SAME:     memref<f32> to memref<f32, strided<[], offset: ?>>
-//       CHECK:   call @linalg_dot_viewsxf32_viewsxf32_viewf32(
-//  CHECK-SAME:     %[[o0]], %[[o1]], %[[o2]]) :
-//  CHECK-SAME:   memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
-
-// -----
-
-#matmul_accesses = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmul_trait = {
-  iterator_types = ["parallel", "parallel", "reduction"],
-  indexing_maps = #matmul_accesses,
-  library_call = "external_outerproduct_matmul"
-}
-
-!vector_type_A = vector<4xf32>
-!vector_type_B = vector<4xf32>
-!vector_type_C = vector<4x4xf32>
-
-!matrix_type_A = memref<?x?x!vector_type_A>
-!matrix_type_B = memref<?x?x!vector_type_B>
-!matrix_type_C = memref<?x?x!vector_type_C>
-
-func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
-  linalg.generic #matmul_trait
-      ins(%A, %B : !matrix_type_A, !matrix_type_B)
-     outs(%C : !matrix_type_C) {
-    ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
-      %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
-      linalg.yield %d: !vector_type_C
-  }
-  return
-}
-// CHECK-LABEL: func @matmul_vec_impl(
-// CHECK:  call @external_outerproduct_matmul(%{{.*}}) :
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d0)>
-
-func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>)  {
-  // expected-error @below {{failed to legalize}}
-  %0 = linalg.generic {
-    indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
-  ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
-  ^bb0(%in: f32, %out: f32): 
-    linalg.yield %in : f32
-  } -> tensor<?xf32>
-  return 
-}
-
-// -----
-
-func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // expected-error @below {{failed to legalize}}
-  %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0 : tensor<4x8xf32>
-}



More information about the Mlir-commits mailing list