[Mlir-commits] [mlir] ff94419 - [mlir][Linalg] Fix crash in LinalgToStandard
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jan 20 00:19:43 PST 2023
Author: Nicolas Vasilache
Date: 2023-01-20T00:17:55-08:00
New Revision: ff94419a287c0b20bf357ab85cf611d4e9bad4c0
URL: https://github.com/llvm/llvm-project/commit/ff94419a287c0b20bf357ab85cf611d4e9bad4c0
DIFF: https://github.com/llvm/llvm-project/commit/ff94419a287c0b20bf357ab85cf611d4e9bad4c0.diff
LOG: [mlir][Linalg] Fix crash in LinalgToStandard
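Properly handle `appendMangledType` failure (e.g. for tensor operand types, which cannot be name-mangled) instead of hitting `llvm_unreachable`; `generateLibraryCallName` now returns an empty name in that case.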
Fixes #59986.
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e411de66da067..ebfb8ef1fe57c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1795,7 +1795,7 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
return llvm::to_vector<4>(concatRanges);
}
-static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
+static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
if (auto memref = t.dyn_cast<MemRefType>()) {
ss << "view";
for (auto size : memref.getShape())
@@ -1804,16 +1804,19 @@ static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
else
ss << size << "x";
appendMangledType(ss, memref.getElementType());
- } else if (auto vec = t.dyn_cast<VectorType>()) {
+ return success();
+ }
+ if (auto vec = t.dyn_cast<VectorType>()) {
ss << "vector";
llvm::interleave(
vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
appendMangledType(ss, vec.getElementType());
+ return success();
} else if (t.isSignlessIntOrIndexOrFloat()) {
ss << t;
- } else {
- llvm_unreachable("Invalid type for linalg library name mangling");
+ return success();
}
+ return failure();
}
std::string mlir::linalg::generateLibraryCallName(Operation *op) {
@@ -1823,11 +1826,14 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
std::replace(name.begin(), name.end(), '.', '_');
llvm::raw_string_ostream ss(name);
ss << "_";
- auto types = op->getOperandTypes();
- llvm::interleave(
- types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
- [&]() { ss << "_"; });
- return ss.str();
+ for (Type t : op->getOperandTypes()) {
+ if (failed(appendMangledType(ss, t)))
+ return std::string();
+ ss << "_";
+ }
+ std::string res = ss.str();
+ res.pop_back();
+ return res;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
index fcb215e6035aa..f50016f9ea477 100644
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ b/mlir/test/Dialect/Linalg/standard.mlir
@@ -71,3 +71,11 @@ func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) {
} -> tensor<?xf32>
return
}
+
+// -----
+
+func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // expected-error @below {{failed to legalize}}
+ %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
More information about the Mlir-commits mailing list