[Mlir-commits] [mlir] f1b9720 - [mlir][Linalg] Start a LinalgToStandard pass and move conversion to library calls.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu May 14 21:35:09 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-15T00:24:03-04:00
New Revision: f1b972041adf565c08d1abca41068d2adcf62702
URL: https://github.com/llvm/llvm-project/commit/f1b972041adf565c08d1abca41068d2adcf62702
DIFF: https://github.com/llvm/llvm-project/commit/f1b972041adf565c08d1abca41068d2adcf62702.diff
LOG: [mlir][Linalg] Start a LinalgToStandard pass and move conversion to library calls.
This revision starts decoupling the include-the-kitchen-sink behavior of the Linalg-to-LLVM lowering by introducing a -convert-linalg-to-std pass.
The lowering of linalg ops to function calls previously produced memref descriptors directly, because the linalg -> std and std -> LLVM patterns ran in the same rewrite.
Separating this step exposes a new issue: the layout is silently type-erased in the process. This revision therefore introduces explicit memref casts to perform the type erasure. To connect everything end-to-end, the LLVM lowering of MemRefCastOp is relaxed: it was artificially more restrictive than the op semantics, which already guarantee that source and target MemRefTypes are cast-compatible. An invalid lowering test thereby becomes valid and is removed.
Differential Revision: https://reviews.llvm.org/D79468
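To make the explicit type erasure concrete, here is a minimal sketch (not part of this patch; @external_fn and the types are hypothetical) of a memref_cast erasing a static layout before a library call:

```mlir
// Hypothetical sketch: erase the static layout of %buf so the callee sees a
// single, layout-erased signature with dynamic size, offset, and stride.
func @caller(%buf: memref<16xf32>) {
  %erased = memref_cast %buf
    : memref<16xf32> to memref<?xf32, offset: ?, strides: [?]>
  call @external_fn(%erased) : (memref<?xf32, offset: ?, strides: [?]>) -> ()
  return
}
func @external_fn(memref<?xf32, offset: ?, strides: [?]>)
```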
Added:
mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/test/Dialect/Linalg/standard.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/test/Conversion/StandardToLLVM/invalid.mlir
mlir/test/Dialect/Linalg/llvm.mlir
mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
new file mode 100644
index 000000000000..6585eaf35ef6
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -0,0 +1,29 @@
+//===- LinalgToStandard.h - Utils to convert from the linalg dialect ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
+#define MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class MLIRContext;
+class ModuleOp;
+template <typename T>
+class OperationPass;
+
+/// Populate the given list with patterns that convert from Linalg to Standard.
+void populateLinalgToStandardConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+/// Create a pass to convert Linalg operations to the Standard dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a14e8a43e2d9..f0724b2d1677 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -141,6 +141,16 @@ def ConvertLinalgToLLVM : Pass<"convert-linalg-to-llvm", "ModuleOp"> {
let constructor = "mlir::createConvertLinalgToLLVMPass()";
}
+//===----------------------------------------------------------------------===//
+// LinalgToStandard
+//===----------------------------------------------------------------------===//
+
+def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
+ let summary = "Convert the operations from the linalg dialect into the "
+ "Standard dialect";
+ let constructor = "mlir::createConvertLinalgToStandardPass()";
+}
+
//===----------------------------------------------------------------------===//
// LinalgToSPIRV
//===----------------------------------------------------------------------===//
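A usage sketch (mirroring the RUN lines updated in linalg_integration_test.mlir below): the new pass runs before the LLVM conversion so that any remaining linalg ops become library calls first.

```mlir
// Pipeline ordering as exercised by the updated integration test:
// RUN: mlir-opt %s -convert-linalg-to-std -convert-linalg-to-llvm
```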
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 3348c13bb4ea..305cfaecb2d3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1745,12 +1745,16 @@ def MemRefCastOp : CastOp<"memref_cast"> {
The `memref_cast` operation converts a memref from one type to an equivalent
type with a compatible shape. The source and destination types are
compatible if:
- a. both are ranked memref types with the same element type, affine mappings,
- address space, and rank but where the individual dimensions may add or
- remove constant dimensions from the memref type.
+
+ a. Both are ranked memref types with the same element type, address space,
+ and rank and:
+ 1. Both have the same layout or both have compatible strided layouts.
+ 2. The individual sizes (resp. offset and strides in the case of strided
+ memrefs) may convert constant dimensions to dynamic dimensions and
+ vice-versa.
If the cast converts any dimensions from an unknown to a known size, then it
- acts as an assertion that fails at runtime of the dynamic dimensions
+ acts as an assertion that fails at runtime if the dynamic dimensions
disagree with resultant destination size.
Example:
@@ -1772,7 +1776,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {
memref<12x4xf32, offset:?, strides: [?, ?]>
```
- b. either or both memref types are unranked with the same element type, and
+ b. Either or both memref types are unranked with the same element type, and
address space.
Example:
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 752f2a45f07d..99ad3b10f6d9 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -723,6 +723,10 @@ AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
/// `t` with simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);
+/// Return a version of `t` whose layout has a fully dynamic offset and
+/// fully dynamic strides. This is used to erase the static layout.
+MemRefType eraseStridedLayout(MemRefType t);
+
/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become
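To illustrate the new helper (a sketch, not part of the patch), eraseStridedLayout keeps the sizes intact and only makes the offset and strides dynamic:

```mlir
// Hypothetical before/after of eraseStridedLayout on a 2-D memref type:
//   before: memref<4x8xf32>                               (identity layout)
//   after:  memref<4x8xf32, offset: ?, strides: [?, ?]>   (erased layout)
```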
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index ef8fc8ceddd6..26c11621e0ac 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -22,6 +22,7 @@
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
+#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 85869f3c6629..3ac3b11b8298 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -7,6 +7,7 @@ add_subdirectory(GPUToSPIRV)
add_subdirectory(GPUToVulkan)
add_subdirectory(LinalgToLLVM)
add_subdirectory(LinalgToSPIRV)
+add_subdirectory(LinalgToStandard)
add_subdirectory(LoopsToGPU)
add_subdirectory(LoopToStandard)
add_subdirectory(StandardToLLVM)
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 3b3bf0f08370..68ac974b5e96 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -349,205 +349,6 @@ class YieldOpConversion : public ConvertToLLVMPattern {
};
} // namespace
-template <typename LinalgOp>
-static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
- return SmallVector<Type, 4>{op->getOperandTypes()};
-}
-
-template <>
-SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
- auto ctx = op->getContext();
- auto indexedGenericOp = cast<IndexedGenericOp>(op);
- auto numLoops = indexedGenericOp.getNumLoops();
-
- SmallVector<Type, 4> result;
- result.reserve(numLoops + op->getNumOperands());
- for (unsigned i = 0; i < numLoops; ++i) {
- result.push_back(IndexType::get(ctx));
- }
- for (auto type : op->getOperandTypes()) {
- result.push_back(type);
- }
- return result;
-}
-
-// Get a SymbolRefAttr containing the library function name for the LinalgOp.
-// If the library function does not exist, insert a declaration.
-template <typename LinalgOp>
-static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
- PatternRewriter &rewriter) {
- auto linalgOp = cast<LinalgOp>(op);
- auto fnName = linalgOp.getLibraryCallName();
- if (fnName.empty()) {
- op->emitWarning("No library call defined for: ") << *op;
- return {};
- }
-
- // fnName is a dynamic std::String, unique it via a SymbolRefAttr.
- FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
- auto module = op->getParentOfType<ModuleOp>();
- if (module.lookupSymbol(fnName)) {
- return fnNameAttr;
- }
-
- SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
- assert(op->getNumResults() == 0 &&
- "Library call for linalg operation can be generated only for ops that "
- "have void return types");
- auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
-
- OpBuilder::InsertionGuard guard(rewriter);
- // Insert before module terminator.
- rewriter.setInsertionPoint(module.getBody(),
- std::prev(module.getBody()->end()));
- FuncOp funcOp =
- rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
- ArrayRef<NamedAttribute>{});
- // Insert a function attribute that will trigger the emission of the
- // corresponding `_mlir_ciface_xxx` interface so that external libraries see
- // a normalized ABI. This interface is added during std to llvm conversion.
- funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
- return fnNameAttr;
-}
-
-namespace {
-
-// LinalgOpConversion<LinalgOp> creates a new call to the
-// `LinalgOp::getLibraryCallName()` function.
-// The implementation of the function can be either in the same module or in an
-// externally linked library.
-template <typename LinalgOp>
-class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
-public:
- using OpRewritePattern<LinalgOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(LinalgOp op,
- PatternRewriter &rewriter) const override {
- auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
- if (!libraryCallName)
- return failure();
-
- rewriter.replaceOpWithNewOp<mlir::CallOp>(
- op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
- return success();
- }
-};
-
-/// Conversion pattern specialization for CopyOp. This kicks in when both input
-/// and output permutations are left unspecified or are the identity.
-template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
-public:
- using OpRewritePattern<CopyOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(CopyOp op,
- PatternRewriter &rewriter) const override {
- auto inputPerm = op.inputPermutation();
- if (inputPerm.hasValue() && !inputPerm->isIdentity())
- return failure();
- auto outputPerm = op.outputPermutation();
- if (outputPerm.hasValue() && !outputPerm->isIdentity())
- return failure();
-
- auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
- if (!libraryCallName)
- return failure();
-
- rewriter.replaceOpWithNewOp<mlir::CallOp>(
- op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
- return success();
- }
-};
-
-/// Conversion pattern specialization for IndexedGenericOp.
-template <>
-class LinalgOpConversion<IndexedGenericOp>
- : public OpRewritePattern<IndexedGenericOp> {
-public:
- using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IndexedGenericOp op,
- PatternRewriter &rewriter) const override {
- auto libraryCallName =
- getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
- if (!libraryCallName)
- return failure();
-
- // TODO(pifon, ntv): Use induction variables values instead of zeros, when
- // IndexedGenericOp is tiled.
- auto zero = rewriter.create<mlir::ConstantOp>(
- op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
- auto indexedGenericOp = cast<IndexedGenericOp>(op);
- auto numLoops = indexedGenericOp.getNumLoops();
- SmallVector<Value, 4> operands;
- operands.reserve(numLoops + op.getNumOperands());
- for (unsigned i = 0; i < numLoops; ++i) {
- operands.push_back(zero);
- }
- for (auto operand : op.getOperands()) {
- operands.push_back(operand);
- }
- rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
- ArrayRef<Type>{}, operands);
- return success();
- }
-};
-
-/// A non-conversion rewrite pattern kicks in to convert CopyOp with
-/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
-/// This interplays together with TransposeOpConversion and
-/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
-class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
-public:
- using OpRewritePattern<CopyOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(CopyOp op,
- PatternRewriter &rewriter) const override {
- Value in = op.input(), out = op.output();
-
- // If either inputPerm or outputPerm are non-identities, insert transposes.
- auto inputPerm = op.inputPermutation();
- if (inputPerm.hasValue() && !inputPerm->isIdentity())
- in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
- AffineMapAttr::get(*inputPerm));
- auto outputPerm = op.outputPermutation();
- if (outputPerm.hasValue() && !outputPerm->isIdentity())
- out = rewriter.create<linalg::TransposeOp>(
- op.getLoc(), out, AffineMapAttr::get(*outputPerm));
-
- // If nothing was transposed, fail and let the conversion kick in.
- if (in == op.input() && out == op.output())
- return failure();
-
- rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
- return success();
- }
-};
-
-/// Populate the given list with patterns that convert from Linalg to Standard.
-static void
-populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
- // attribute values such as kernel striding and dilation.
- // clang-format off
- patterns.insert<
- CopyTransposeConversion,
- LinalgOpConversion<ConvOp>,
- LinalgOpConversion<PoolingMaxOp>,
- LinalgOpConversion<PoolingMinOp>,
- LinalgOpConversion<PoolingSumOp>,
- LinalgOpConversion<CopyOp>,
- LinalgOpConversion<DotOp>,
- LinalgOpConversion<FillOp>,
- LinalgOpConversion<GenericOp>,
- LinalgOpConversion<IndexedGenericOp>,
- LinalgOpConversion<MatmulOp>,
- LinalgOpConversion<MatvecOp>>(ctx);
- // clang-format on
-}
-
-} // namespace
-
/// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
@@ -579,7 +380,6 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
populateVectorToLoopsConversionPatterns(patterns, &getContext());
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, patterns);
- populateLinalgToStandardConversionPatterns(patterns, &getContext());
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
LLVMConversionTarget target(getContext());
diff --git a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt b/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
new file mode 100644
index 000000000000..84f19499bf71
--- /dev/null
+++ b/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRLinalgToStandard
+ LinalgToStandard.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToStandard
+
+ DEPENDS
+ MLIRConversionPassIncGen
+)
+
+target_link_libraries(MLIRLinalgToStandard
+ PUBLIC
+ MLIREDSC
+ MLIRIR
+ MLIRLinalgOps
+ MLIRSCF
+ LLVMCore
+ LLVMSupport
+ )
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
new file mode 100644
index 000000000000..ca6ca8b24732
--- /dev/null
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -0,0 +1,271 @@
+//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Helper function to extract the operand types that are passed to the
+/// generated CallOp. MemRefTypes have their layout canonicalized since the
+/// information is not used in signature generation.
+/// Note that static size information is not modified.
+template <typename LinalgOp>
+static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
+ SmallVector<Type, 4> result;
+ result.reserve(op->getNumOperands());
+ for (auto type : op->getOperandTypes()) {
+ // The underlying descriptor type (e.g. LLVM) does not have layout
+ // information. Canonicalizing the type at the level of std when going into
+ // a library call avoids needing to introduce DialectCastOp.
+ if (auto memrefType = type.dyn_cast<MemRefType>())
+ result.push_back(eraseStridedLayout(memrefType));
+ else
+ result.push_back(type);
+ }
+ return result;
+}
+
+template <>
+SmallVector<Type, 4> extractOperandTypes<IndexedGenericOp>(Operation *op) {
+ auto *ctx = op->getContext();
+ auto indexedGenericOp = cast<IndexedGenericOp>(op);
+ auto numLoops = indexedGenericOp.getNumLoops();
+
+ SmallVector<Type, 4> result(numLoops, IndexType::get(ctx));
+ auto canonicalizedOperands = extractOperandTypes<LinalgOp>(op);
+ result.append(canonicalizedOperands.begin(), canonicalizedOperands.end());
+ return result;
+}
+
+// Get a SymbolRefAttr containing the library function name for the LinalgOp.
+// If the library function does not exist, insert a declaration.
+template <typename LinalgOp>
+static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
+ PatternRewriter &rewriter) {
+ auto linalgOp = cast<LinalgOp>(op);
+ auto fnName = linalgOp.getLibraryCallName();
+ if (fnName.empty()) {
+ op->emitWarning("No library call defined for: ") << *op;
+ return {};
+ }
+
+ // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
+ FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
+ auto module = op->getParentOfType<ModuleOp>();
+ if (module.lookupSymbol(fnName)) {
+ return fnNameAttr;
+ }
+
+ SmallVector<Type, 4> inputTypes(extractOperandTypes<LinalgOp>(op));
+ assert(op->getNumResults() == 0 &&
+ "Library call for linalg operation can be generated only for ops that "
+ "have void return types");
+ auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ // Insert before module terminator.
+ rewriter.setInsertionPoint(module.getBody(),
+ std::prev(module.getBody()->end()));
+ FuncOp funcOp =
+ rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
+ ArrayRef<NamedAttribute>{});
+ // Insert a function attribute that will trigger the emission of the
+ // corresponding `_mlir_ciface_xxx` interface so that external libraries see
+ // a normalized ABI. This interface is added during std to llvm conversion.
+ funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
+ return fnNameAttr;
+}
+
+namespace {
+
+SmallVector<Value, 4>
+createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
+ ValueRange operands) {
+ SmallVector<Value, 4> res;
+ res.reserve(operands.size());
+ for (auto op : operands) {
+ auto memrefType = op.getType().dyn_cast<MemRefType>();
+ if (!memrefType) {
+ res.push_back(op);
+ continue;
+ }
+ Value cast =
+ b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), op);
+ res.push_back(cast);
+ }
+ return res;
+}
+
+// LinalgOpConversion<LinalgOp> creates a new call to the type-canonicalized
+// `LinalgOp::getLibraryCallName()` function.
+// The implementation of the function can be either in the same module or in an
+// externally linked library.
+template <typename LinalgOp>
+class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
+public:
+ using OpRewritePattern<LinalgOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
+ if (!libraryCallName)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(
+ op, libraryCallName.getValue(), ArrayRef<Type>{},
+ createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
+ op.getOperands()));
+ return success();
+ }
+};
+
+/// Conversion pattern specialization for CopyOp. This kicks in when both input
+/// and output permutations are left unspecified or are the identity.
+template <>
+class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
+public:
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CopyOp op,
+ PatternRewriter &rewriter) const override {
+ auto inputPerm = op.inputPermutation();
+ if (inputPerm.hasValue() && !inputPerm->isIdentity())
+ return failure();
+ auto outputPerm = op.outputPermutation();
+ if (outputPerm.hasValue() && !outputPerm->isIdentity())
+ return failure();
+
+ auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
+ if (!libraryCallName)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(
+ op, libraryCallName.getValue(), ArrayRef<Type>{},
+ createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
+ op.getOperands()));
+ return success();
+ }
+};
+
+/// Conversion pattern specialization for IndexedGenericOp.
+template <>
+class LinalgOpConversion<IndexedGenericOp>
+ : public OpRewritePattern<IndexedGenericOp> {
+public:
+ using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexedGenericOp op,
+ PatternRewriter &rewriter) const override {
+ auto libraryCallName =
+ getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
+ if (!libraryCallName)
+ return failure();
+
+ // TODO(pifon, ntv): Use induction variable values instead of zeros, when
+ // IndexedGenericOp is tiled.
+ auto zero = rewriter.create<mlir::ConstantOp>(
+ op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
+ auto indexedGenericOp = cast<IndexedGenericOp>(op);
+ auto numLoops = indexedGenericOp.getNumLoops();
+ SmallVector<Value, 4> operands;
+ operands.reserve(numLoops + op.getNumOperands());
+ for (unsigned i = 0; i < numLoops; ++i)
+ operands.push_back(zero);
+ for (auto operand : op.getOperands())
+ operands.push_back(operand);
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(
+ op, libraryCallName.getValue(), ArrayRef<Type>{},
+ createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
+ return success();
+ }
+};
+
+/// A non-conversion rewrite pattern that converts a CopyOp with permutations
+/// into a sequence of TransposeOps and a permutation-free CopyOp. This
+/// interplays with TransposeOpConversion and LinalgOpConversion<CopyOp> to
+/// create a path to the LLVM dialect.
+class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
+public:
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CopyOp op,
+ PatternRewriter &rewriter) const override {
+ Value in = op.input(), out = op.output();
+
+ // If either inputPerm or outputPerm are non-identities, insert transposes.
+ auto inputPerm = op.inputPermutation();
+ if (inputPerm.hasValue() && !inputPerm->isIdentity())
+ in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
+ AffineMapAttr::get(*inputPerm));
+ auto outputPerm = op.outputPermutation();
+ if (outputPerm.hasValue() && !outputPerm->isIdentity())
+ out = rewriter.create<linalg::TransposeOp>(
+ op.getLoc(), out, AffineMapAttr::get(*outputPerm));
+
+ // If nothing was transposed, fail and let the conversion kick in.
+ if (in == op.input() && out == op.output())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
+ return success();
+ }
+};
+} // namespace
+
+/// Populate the given list with patterns that convert from Linalg to Standard.
+void mlir::populateLinalgToStandardConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
+ // attribute values such as kernel striding and dilation.
+ // clang-format off
+ patterns.insert<
+ CopyTransposeConversion,
+ LinalgOpConversion<ConvOp>,
+ LinalgOpConversion<PoolingMaxOp>,
+ LinalgOpConversion<PoolingMinOp>,
+ LinalgOpConversion<PoolingSumOp>,
+ LinalgOpConversion<CopyOp>,
+ LinalgOpConversion<DotOp>,
+ LinalgOpConversion<FillOp>,
+ LinalgOpConversion<GenericOp>,
+ LinalgOpConversion<IndexedGenericOp>,
+ LinalgOpConversion<MatmulOp>,
+ LinalgOpConversion<MatvecOp>>(ctx);
+ // clang-format on
+}
+
+namespace {
+struct ConvertLinalgToStandardPass
+ : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertLinalgToStandardPass::runOnOperation() {
+ auto module = getOperation();
+ ConversionTarget target(getContext());
+ target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
+ target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
+ target.addLegalOp<linalg::TransposeOp, linalg::ReshapeOp, linalg::RangeOp>();
+ OwningRewritePatternList patterns;
+ populateLinalgToStandardConversionPatterns(patterns, &getContext());
+ if (failed(applyFullConversion(module, target, patterns)))
+ signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertLinalgToStandardPass() {
+ return std::make_unique<ConvertLinalgToStandardPass>();
+}
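For reference, a sketch of the declaration this code materializes for linalg.dot (grounded in the standard.mlir test added below, where the layout maps print as #map6 and #map7):

```mlir
#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
#map1 = affine_map<()[s0] -> (s0)>
// Declaration inserted by getLibraryCallSymbolRef; llvm.emit_c_interface
// later triggers emission of a _mlir_ciface_* wrapper during std-to-llvm
// conversion.
func @linalg_dot_viewsxf32_viewsxf32_viewf32(
    memref<?xf32, #map0>, memref<?xf32, #map0>, memref<f32, #map1>)
  attributes {llvm.emit_c_interface}
```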
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 66fe763d88bd..8218976418c5 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2003,15 +2003,14 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
Type srcType = memRefCastOp.getOperand().getType();
Type dstType = memRefCastOp.getType();
- if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
- MemRefType sourceType =
- memRefCastOp.getOperand().getType().cast<MemRefType>();
- MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
- return (isSupportedMemRefType(targetType) &&
- isSupportedMemRefType(sourceType))
- ? success()
- : failure();
- }
+ // MemRefCastOp reduces to a bitcast in the ranked MemRef case and can be
+ // used for type erasure. For now it must preserve the underlying element
+ // type and requires source and result types to have the same rank.
+ // Therefore, perform a sanity check that the underlying structs are the
+ // same. Once op semantics are relaxed we can revisit.
+ if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
+ return success(typeConverter.convertType(srcType) ==
+ typeConverter.convertType(dstType));
// At least one of the operands is unranked type
assert(srcType.isa<UnrankedMemRefType>() ||
@@ -2034,10 +2033,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
auto targetStructType = typeConverter.convertType(memRefCastOp.getType());
auto loc = op->getLoc();
+ // MemRefCastOp reduces to a bitcast in the ranked MemRef case.
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
- // memref_cast is defined for source and destination memref types with the
- // same element type, same mappings, same address space and same rank.
- // Therefore a simple bitcast suffices. If not it is undefined behavior.
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
transformed.source());
} else if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
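A sketch of what the relaxed check accepts (hypothetical example): a cast is legal exactly when both sides convert to the same LLVM descriptor struct, in which case it lowers to a single llvm.bitcast.

```mlir
func @cast(%arg: memref<16xf32>) -> memref<?xf32, offset: ?, strides: [?]> {
  // Both types lower to the same descriptor:
  // !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
  %0 = memref_cast %arg : memref<16xf32> to memref<?xf32, offset: ?, strides: [?]>
  return %0 : memref<?xf32, offset: ?, strides: [?]>
}
```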
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 3f4a7ec6ef6e..808b4fc910d2 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -764,6 +764,14 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
return simplifyAffineExpr(expr, numDims, nSymbols);
}
+/// Return a version of `t` whose layout has a fully dynamic offset and
+/// fully dynamic strides. This is used to erase the static layout.
+MemRefType mlir::eraseStridedLayout(MemRefType t) {
+ auto val = ShapedType::kDynamicStrideOrOffset;
+ return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
+ SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
+}
+
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
MLIRContext *context) {
SmallVector<AffineExpr, 4> exprs;
diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
index 1be148707458..56e661242336 100644
--- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
@@ -1,18 +1,5 @@
// RUN: mlir-opt %s -convert-std-to-llvm -verify-diagnostics -split-input-file
-#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
-
-func @invalid_memref_cast(%arg0: memref<?x?xf64>) {
- %c1 = constant 1 : index
- %c0 = constant 0 : index
- // expected-error at +1 {{'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values, but got '!llvm<"{ double*, double*, i64, [2 x i64], [2 x i64] }">'}}
- %5 = memref_cast %arg0 : memref<?x?xf64> to memref<?x?xf64, #map1>
- %25 = std.subview %5[%c0, %c0][%c1, %c1][1, 1] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
- return
-}
-
-// -----
-
func @mlir_cast_to_llvm(%0 : index) -> !llvm.i64 {
// expected-error at +1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
%1 = llvm.mlir.cast %0 : index to !llvm.i64
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 377775cd9dc6..9b052fd2fab4 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -1,5 +1,4 @@
// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefix=LLVM-LOOPS
func @range(%arg0: index) {
%c0 = constant 0 : index
@@ -48,16 +47,6 @@ func @slice(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: !linalg.range)
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
- linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
- return
-}
-// CHECK-LABEL: func @dot
-// CHECK: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}) :
-// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
-// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
-// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64
-
func @slice_with_range_and_index(%arg0: memref<?x?xf64, offset: ?, strides: [?, 1]>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
@@ -80,15 +69,6 @@ func @slice_with_range_and_index(%arg0: memref<?x?xf64, offset: ?, strides: [?,
// CHECK: llvm.insertvalue %{{.*}}[3, 0] : !llvm<"{ double*, double*, i64, [1 x i64], [1 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}[4, 0] : !llvm<"{ double*, double*, i64, [1 x i64], [1 x i64] }">
-func @copy(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
- linalg.copy(%arg0, %arg1) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
- return
-}
-// CHECK-LABEL: func @copy
-// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32({{.*}}) :
-// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
-// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
-
func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
%0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
return
@@ -105,115 +85,6 @@ func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
- linalg.copy(%arg0, %arg1) {inputPermutation = affine_map<(i, j, k) -> (i, k, j)>,
- outputPermutation = affine_map<(i, j, k) -> (k, j, i)>}
- : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
- return
-}
-// CHECK-LABEL: func @copy
-// Transpose input
-// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// Transpose output
-// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// Call external copy.
-// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32
-
-#matmul_accesses = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (m, n)>
-]
-#matmul_trait = {
- args_in = 2,
- args_out = 1,
- iterator_types = ["parallel", "parallel", "reduction"],
- indexing_maps = #matmul_accesses,
- library_call = "external_outerproduct_matmul"
-}
-
-!vector_type_A = type vector<4xf32>
-!vector_type_B = type vector<4xf32>
-!vector_type_C = type vector<4x4xf32>
-
-!matrix_type_A = type memref<?x?x!vector_type_A>
-!matrix_type_B = type memref<?x?x!vector_type_B>
-!matrix_type_C = type memref<?x?x!vector_type_C>
-
-func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
- linalg.generic #matmul_trait %A, %B, %C {
- ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
- %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
- linalg.yield %d: !vector_type_C
- } : !matrix_type_A, !matrix_type_B, !matrix_type_C
-
- return
-}
-// CHECK-LABEL: func @matmul_vec_impl(
-// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) :
-// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
-// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
-// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
-
-// LLVM-LOOPS-LABEL: func @matmul_vec_impl(
-// LLVM-LOOPS-SAME: %[[A:.*0]]: memref<?x?xvector<4xf32>>,
-// LLVM-LOOPS-SAME: %[[B:.*1]]: memref<?x?xvector<4xf32>>,
-// LLVM-LOOPS-SAME: %[[C:.*2]]: memref<?x?xvector<4x4xf32>>)
-// LLVM-LOOPS: %[[C0:.*]] = constant 0 : index
-// LLVM-LOOPS: %[[C1:.*]] = constant 1 : index
-// LLVM-LOOPS: %[[T0:.*]] = dim %[[A]], 0 : memref<?x?xvector<4xf32>>
-// LLVM-LOOPS: %[[T1:.*]] = dim %[[A]], 1 : memref<?x?xvector<4xf32>>
-// LLVM-LOOPS: %[[T2:.*]] = dim %[[B]], 1 : memref<?x?xvector<4xf32>>
-// LLVM-LOOPS: scf.for %[[I:.*]] = %[[C0]] to %[[T0]] step %[[C1]] {
-// LLVM-LOOPS: scf.for %[[J:.*]] = %[[C0]] to %[[T2]] step %[[C1]] {
-// LLVM-LOOPS: scf.for %[[K:.*]] = %[[C0]] to %[[T1]] step %[[C1]] {
-// LLVM-LOOPS: %[[T3:.*]] = load %[[A]][%[[I]], %[[K]]] : memref<?x?xvector<4xf32>>
-// LLVM-LOOPS: %[[T4:.*]] = load %[[B]][%[[K]], %[[J]]] : memref<?x?xvector<4xf32>>
-// LLVM-LOOPS: %[[T5:.*]] = load %[[C]][%[[I]], %[[J]]] : memref<?x?xvector<4x4xf32>>
-// LLVM-LOOPS: %[[T6:.*]] = vector.outerproduct %3, %4, %5 : vector<4xf32>, vector<4xf32>
-// LLVM-LOOPS: store %[[T6]], %[[C]][%[[I]], %[[J]]] : memref<?x?xvector<4x4xf32>>
-
-#indexed_matmul_trait = {
- args_in = 2,
- args_out = 1,
- iterator_types = ["parallel", "parallel", "reduction"],
- indexing_maps = #matmul_accesses,
- library_call = "external_indexed_outerproduct_matmul"
-}
-func @matmul_vec_indexed(%A: !matrix_type_A,
- %B: !matrix_type_B,
- %C: !matrix_type_C) {
- linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
- ^bb0(%i: index, %j: index, %k: index,
- %a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
- %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
- linalg.yield %d: !vector_type_C
- } : !matrix_type_A, !matrix_type_B, !matrix_type_C
- return
-}
-// CHECK-LABEL: func @matmul_vec_indexed(
-// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}) :
-// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
-// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
-// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
-
func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
// Reshapes that expand a contiguous tensor with some 1's.
%0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
new file mode 100644
index 000000000000..b94c504434ed
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/standard.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-opt %s -convert-linalg-to-std | FileCheck %s
+
+// CHECK-DAG: #[[map0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[map1:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
+// CHECK-DAG: #[[map2:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d2 * s2 + d1)>
+// CHECK-DAG: #[[map3:.*]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[map4:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
+// CHECK-DAG: #[[map5:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
+// CHECK-DAG: #[[map6:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-DAG: #[[map7:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[map8:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+
+func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>,
+ %arg1: memref<?xf32, offset: ?, strides: [1]>,
+ %arg2: memref<f32>) {
+ linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>,
+ memref<f32>
+ return
+}
+// CHECK-LABEL: func @dot(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9]*]]: memref<?xf32, #[[map0]]>,
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9]*]]: memref<?xf32, #[[map0]]>,
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9]*]]: memref<f32>) {
+// CHECK: %[[o0:.*]] = memref_cast %[[arg0]] :
+// CHECK-SAME: memref<?xf32, #[[map0]]> to memref<?xf32, #[[map6]]>
+// CHECK: %[[o1:.*]] = memref_cast %[[arg1]] :
+// CHECK-SAME: memref<?xf32, #[[map0]]> to memref<?xf32, #[[map6]]>
+// CHECK: %[[o2:.*]] = memref_cast %[[arg2]] :
+// CHECK-SAME: memref<f32> to memref<f32, #[[map7]]>
+// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32(
+// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) :
+// CHECK-SAME: memref<?xf32, #[[map6]]>, memref<?xf32, #[[map6]]>, memref<f32, #[[map7]]>
+
+func @copy(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+ linalg.copy(%arg0, %arg1) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+ return
+}
+// CHECK-LABEL: func @copy(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9]*]]: memref<?x?x?xf32, #[[map1]]>,
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9]*]]: memref<?x?x?xf32, #[[map1]]>) {
+// CHECK: %[[o0:.*]] = memref_cast %[[arg0]] :
+// CHECK-SAME: memref<?x?x?xf32, #[[map1]]> to memref<?x?x?xf32, #[[map8]]>
+// CHECK: %[[o1:.*]] = memref_cast %[[arg1]] :
+// CHECK-SAME: memref<?x?x?xf32, #[[map1]]> to memref<?x?x?xf32, #[[map8]]>
+// CHECK: call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%[[o0]], %[[o1]]) :
+// CHECK-SAME: memref<?x?x?xf32, #[[map8]]>, memref<?x?x?xf32, #[[map8]]>
+
+func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+ linalg.copy(%arg0, %arg1) {inputPermutation = affine_map<(i, j, k) -> (i, k, j)>,
+ outputPermutation = affine_map<(i, j, k) -> (k, j, i)>}
+ : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+ return
+}
+// CHECK-LABEL: func @copy_transpose(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9]*]]: memref<?x?x?xf32, #[[map1]]>,
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9]*]]: memref<?x?x?xf32, #[[map1]]>) {
+// CHECK: %[[t0:.*]] = linalg.transpose %[[arg0]]
+// CHECK-SAME: (d0, d1, d2) -> (d0, d2, d1) : memref<?x?x?xf32, #[[map1]]>
+// CHECK: %[[t1:.*]] = linalg.transpose %[[arg1]]
+// CHECK-SAME: (d0, d1, d2) -> (d2, d1, d0) : memref<?x?x?xf32, #[[map1]]>
+// CHECK: %[[o0:.*]] = memref_cast %[[t0]] :
+// CHECK-SAME: memref<?x?x?xf32, #[[map2]]> to memref<?x?x?xf32, #[[map8]]>
+// CHECK: %[[o1:.*]] = memref_cast %[[t1]] :
+// CHECK-SAME: memref<?x?x?xf32, #[[map4]]> to memref<?x?x?xf32, #[[map8]]>
+// CHECK: call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%[[o0]], %[[o1]]) :
+// CHECK-SAME: memref<?x?x?xf32, #[[map8]]>, memref<?x?x?xf32, #[[map8]]>
+
+#matmul_accesses = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmul_trait = {
+ args_in = 2,
+ args_out = 1,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ indexing_maps = #matmul_accesses,
+ library_call = "external_outerproduct_matmul"
+}
+
+!vector_type_A = type vector<4xf32>
+!vector_type_B = type vector<4xf32>
+!vector_type_C = type vector<4x4xf32>
+
+!matrix_type_A = type memref<?x?x!vector_type_A>
+!matrix_type_B = type memref<?x?x!vector_type_B>
+!matrix_type_C = type memref<?x?x!vector_type_C>
+
+func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
+ linalg.generic #matmul_trait %A, %B, %C {
+ ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
+ %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
+ linalg.yield %d: !vector_type_C
+ } : !matrix_type_A, !matrix_type_B, !matrix_type_C
+
+ return
+}
+// CHECK-LABEL: func @matmul_vec_impl(
+// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
+
+#indexed_matmul_trait = {
+ args_in = 2,
+ args_out = 1,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ indexing_maps = #matmul_accesses,
+ library_call = "external_indexed_outerproduct_matmul"
+}
+func @matmul_vec_indexed(%A: !matrix_type_A,
+ %B: !matrix_type_B,
+ %C: !matrix_type_C) {
+ linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
+ ^bb0(%i: index, %j: index, %k: index,
+ %a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
+ %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
+ linalg.yield %d: !vector_type_C
+ } : !matrix_type_A, !matrix_type_B, !matrix_type_C
+ return
+}
+// CHECK-LABEL: func @matmul_vec_indexed(
+// CHECK: %[[ZERO:.*]] = constant 0 : index
+// CHECK: call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}})
diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
index e8030c870605..9400792837c4 100644
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -1,24 +1,24 @@
-// RUN: mlir-opt %s -convert-linalg-to-llvm \
+// RUN: mlir-opt %s -convert-linalg-to-std -convert-linalg-to-llvm \
// RUN: | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm \
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \
// RUN: | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-llvm \
+// RUN: mlir-opt %s -convert-linalg-to-std -convert-linalg-to-llvm \
// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm \
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \
// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-loops -convert-linalg-to-llvm \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \
// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-llvm \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-std -convert-linalg-to-llvm \
// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
// RUN: | FileCheck %s