[Mlir-commits] [mlir] 2ada5cb - [mlir][linalg] Fix bug in InferStaticShapeOfOperands pattern
Vladislav Vinogradov
llvmlistbot at llvm.org
Wed Nov 16 01:19:40 PST 2022
Author: Vladislav Vinogradov
Date: 2022-11-16T12:19:16+03:00
New Revision: 2ada5cbea47bdf7d8df70981315f352237cc2222
URL: https://github.com/llvm/llvm-project/commit/2ada5cbea47bdf7d8df70981315f352237cc2222
DIFF: https://github.com/llvm/llvm-project/commit/2ada5cbea47bdf7d8df70981315f352237cc2222.diff
LOG: [mlir][linalg] Fix bug in InferStaticShapeOfOperands pattern
The pattern tries to deduce a static shape from the `tensor.cast` producers of a linalg operation's operands.
The original code unconditionally casts the type of the `tensor.cast` source to `RankedTensorType`.
But `tensor.cast` can also operate on `UnrankedTensorType`, so this cast either fails an assertion
in debug builds or introduces undefined behavior (UB) in release builds.
The patch replaces the unconditional `cast` with `dyn_cast` and checks the result of the cast before using it.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D137775
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
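To illustrate the difference between the two casts, here is a minimal, self-contained C++ sketch. It does not use MLIR. `TensorTypeBase`, `RankedTensor`, `UnrankedTensor`, `checkedCast` and `dynCast` are hypothetical stand-ins that imitate LLVM-style RTTI through a `classof` predicate: `checkedCast` behaves like `cast<>` (it asserts on a kind mismatch), and `dynCast` behaves like `dyn_cast<>` (it returns a null result, here `nullptr`).

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-ins for tensor types; NOT the real MLIR API.
// The kind tag plays the role of LLVM-style RTTI.
struct TensorTypeBase {
  enum class Kind { Ranked, Unranked };
  explicit TensorTypeBase(Kind k) : kind(k) {}
  virtual ~TensorTypeBase() = default;
  Kind kind;
};

// Models a ranked type such as tensor<3x4xf32>.
struct RankedTensor : TensorTypeBase {
  explicit RankedTensor(std::vector<int64_t> s)
      : TensorTypeBase(Kind::Ranked), shape(std::move(s)) {}
  static bool classof(const TensorTypeBase *t) { return t->kind == Kind::Ranked; }
  std::vector<int64_t> shape;
};

// Models an unranked type such as tensor<*xf32>.
struct UnrankedTensor : TensorTypeBase {
  UnrankedTensor() : TensorTypeBase(Kind::Unranked) {}
  static bool classof(const TensorTypeBase *t) { return t->kind == Kind::Unranked; }
};

// Like cast<>: asserts in debug builds; with NDEBUG defined the check
// disappears and a mismatched kind yields a bogus pointer (UB on use).
template <typename To> const To *checkedCast(const TensorTypeBase *t) {
  assert(To::classof(t) && "checkedCast to incompatible type");
  return static_cast<const To *>(t);
}

// Like dyn_cast<>: returns nullptr instead of asserting on a mismatch,
// so the caller must test the result before using it.
template <typename To> const To *dynCast(const TensorTypeBase *t) {
  return To::classof(t) ? static_cast<const To *>(t) : nullptr;
}
```

With an unranked value, `dynCast<RankedTensor>` yields `nullptr`, so the caller can simply skip shape inference, which is what the added `castSourceType &&` guard does. Calling `checkedCast<RankedTensor>` on the same value would trip the assertion or, with `NDEBUG` defined, hand back a pointer of the wrong type.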
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 32c8dd65678bf..63e2ef38554c5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2236,8 +2236,8 @@ static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
if (parentOp) {
if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
Value castSource = castOp.getSource();
- auto castSourceType = castSource.getType().cast<RankedTensorType>();
- if (castSourceType.hasStaticShape())
+ auto castSourceType = castSource.getType().dyn_cast<RankedTensorType>();
+ if (castSourceType && castSourceType.hasStaticShape())
sourceShape = castSourceType.getShape();
}
}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 3f1118334ef5c..1fe5fe50b2c20 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -47,7 +47,6 @@ func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tenso
// -----
-
// CHECK-LABEL: func @tensor.cast(
func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
-> tensor<3x?xf32>
@@ -68,6 +67,30 @@ func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3
// -----
+// CHECK-LABEL: func @tensor.cast.unranked(
+func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : tensor<*xf32>)
+ -> tensor<*xf32>
+{
+ // CHECK: tensor.cast
+ // CHECK: tensor.cast
+ // CHECK: tensor.cast
+ %ta = tensor.cast %a : tensor<*xf32> to tensor<?x?xf32>
+ %tb = tensor.cast %b : tensor<*xf32> to tensor<?x?xf32>
+ %tc = tensor.cast %c : tensor<*xf32> to tensor<?x?xf32>
+
+ // CHECK: linalg.matmul ins({{.*}}tensor<?x?xf32>, tensor<?x?xf32>)
+ // CHECK-SAME: outs({{.*}}tensor<?x?xf32>) -> tensor<?x?xf32>
+ %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // CHECK: tensor.cast
+ %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<*xf32>
+
+ return %1: tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @linalg_effects(
// CHECK-SAME: %[[A:[a-z0-9]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[B:[a-z0-9]*]]: memref<?x?xf32>
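The guarded check in the `LinalgOps.cpp` hunk can be modeled as a small shape-selection function. This is a hedged, self-contained sketch rather than the MLIR code. `Shape`, `MaybeRankedType`, `kDynamic`, `hasStaticShape` and `selectSourceShape` are hypothetical names, and an unranked type is modeled as `std::nullopt`. The sketch also assumes, since the hunk does not show it, that `sourceShape` otherwise keeps the operand's own shape.

```cpp
#include <cassert>  // for assert() in the usage checks
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model: kDynamic marks a `?` dimension; a ranked type is a
// Shape, and an unranked type (like tensor<*xf32>) is std::nullopt.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;
using MaybeRankedType = std::optional<Shape>;

// True when no dimension is dynamic (a rank-0 shape counts as static).
bool hasStaticShape(const Shape &s) {
  for (int64_t d : s) {
    if (d == kDynamic)
      return false;
  }
  return true;
}

// Mirrors the fixed condition: use the tensor.cast source's shape only when
// that source is ranked AND fully static; otherwise keep the operand's shape.
Shape selectSourceShape(const Shape &operandShape,
                        const MaybeRankedType &castSourceType) {
  if (castSourceType && hasStaticShape(*castSourceType))
    return *castSourceType;
  return operandShape;
}
```

For the new `@tensor.cast.unranked` test, every cast source is `tensor<*xf32>`, so the sketch returns the operand's `tensor<?x?xf32>` shape unchanged and no static sizes are inferred. This matches the test's expectation that `linalg.matmul` keeps its dynamic `tensor<?x?xf32>` types.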