[Mlir-commits] [mlir] 2ada5cb - [mlir][linalg] Fix bug in InferStaticShapeOfOperands pattern

Vladislav Vinogradov llvmlistbot at llvm.org
Wed Nov 16 01:19:40 PST 2022


Author: Vladislav Vinogradov
Date: 2022-11-16T12:19:16+03:00
New Revision: 2ada5cbea47bdf7d8df70981315f352237cc2222

URL: https://github.com/llvm/llvm-project/commit/2ada5cbea47bdf7d8df70981315f352237cc2222
DIFF: https://github.com/llvm/llvm-project/commit/2ada5cbea47bdf7d8df70981315f352237cc2222.diff

LOG: [mlir][linalg] Fix bug in InferStaticShapeOfOperands pattern

The pattern tries to deduce static shapes for the operands of a linalg operation from their `tensor.cast` producers.
The original code unconditionally casts the type of the `tensor.cast` source to `RankedTensorType`.
But `tensor.cast` can also operate on `UnrankedTensorType`, so this cast either fails an assertion
in debug builds or introduces undefined behavior (UB) in release builds.
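To make the failure mode concrete, here is a small self-contained C++ analogy. It is not MLIR code: `RankedTensorType`, `UnrankedTensorType`, the `TensorType` variant and the helper `sourceShapeUnchecked` are hypothetical stand-ins. `std::get` plays the role of the unconditional `cast<RankedTensorType>()`. It throws `std::bad_variant_access` for an unranked source, whereas the real `cast<>` asserts in debug builds and performs no check at all in release builds.

```cpp
#include <cstdint>
#include <optional>
#include <variant>
#include <vector>

// Hypothetical stand-ins for MLIR's tensor types; this is NOT the MLIR API.
// A negative dimension stands for a dynamic size ('?') in this sketch.
struct RankedTensorType {
  std::vector<int64_t> shape;
  bool hasStaticShape() const {
    for (int64_t d : shape)
      if (d < 0)
        return false;
    return true;
  }
};
struct UnrankedTensorType {};  // models e.g. tensor<*xf32>: no rank, no shape

using TensorType = std::variant<RankedTensorType, UnrankedTensorType>;

// Models the original code path, which assumes the tensor.cast source is ranked.
// std::get checks the alternative and throws std::bad_variant_access for an
// unranked source, standing in for the assertion in cast<RankedTensorType>().
std::optional<std::vector<int64_t>>
sourceShapeUnchecked(const TensorType &castSourceType) {
  const auto &ranked = std::get<RankedTensorType>(castSourceType);
  if (ranked.hasStaticShape())
    return ranked.shape;
  return std::nullopt;
}
```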

The patch replaces the unconditional cast with `dyn_cast` and checks the result before using it.
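The fixed logic can be sketched with the same kind of hypothetical stand-ins (again not the MLIR API; `inferSourceShape` is an illustrative name). `std::get_if` returns `nullptr` when the alternative does not match, much as `dyn_cast<RankedTensorType>()` returns a null type, so a shape is only taken from a ranked, fully static source.

```cpp
#include <cstdint>
#include <optional>
#include <variant>
#include <vector>

// Hypothetical stand-ins for MLIR's tensor types; this is NOT the MLIR API.
// A negative dimension stands for a dynamic size ('?') in this sketch.
struct RankedTensorType {
  std::vector<int64_t> shape;
  bool hasStaticShape() const {
    for (int64_t d : shape)
      if (d < 0)
        return false;
    return true;
  }
};
struct UnrankedTensorType {};  // models e.g. tensor<*xf32>

using TensorType = std::variant<RankedTensorType, UnrankedTensorType>;

// Mirrors the patched check: get_if yields nullptr for an unranked source
// (like a failed dyn_cast), and the shape is only used when it is fully static.
std::optional<std::vector<int64_t>>
inferSourceShape(const TensorType &castSourceType) {
  const auto *ranked = std::get_if<RankedTensorType>(&castSourceType);
  if (ranked && ranked->hasStaticShape())
    return ranked->shape;
  return std::nullopt;  // unranked or partially dynamic: nothing to infer
}
```

With this guard an unranked source such as `tensor<*xf32>` simply yields no static shape, which is what the new `@tensor.cast.unranked` test expects: the casts and the `linalg.matmul` stay unchanged.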

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D137775

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 32c8dd65678bf..63e2ef38554c5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2236,8 +2236,8 @@ static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
     if (parentOp) {
       if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
         Value castSource = castOp.getSource();
-        auto castSourceType = castSource.getType().cast<RankedTensorType>();
-        if (castSourceType.hasStaticShape())
+        auto castSourceType = castSource.getType().dyn_cast<RankedTensorType>();
+        if (castSourceType && castSourceType.hasStaticShape())
           sourceShape = castSourceType.getShape();
       }
     }

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 3f1118334ef5c..1fe5fe50b2c20 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -47,7 +47,6 @@ func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tenso
 
 // -----
 
-
 // CHECK-LABEL: func @tensor.cast(
 func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
   -> tensor<3x?xf32>
@@ -68,6 +67,30 @@ func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3
 
 // -----
 
+// CHECK-LABEL: func @tensor.cast.unranked(
+func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : tensor<*xf32>)
+  -> tensor<*xf32>
+{
+  //      CHECK:  tensor.cast
+  //      CHECK:  tensor.cast
+  //      CHECK:  tensor.cast
+  %ta = tensor.cast %a : tensor<*xf32> to tensor<?x?xf32>
+  %tb = tensor.cast %b : tensor<*xf32> to tensor<?x?xf32>
+  %tc = tensor.cast %c : tensor<*xf32> to tensor<?x?xf32>
+
+  //      CHECK:  linalg.matmul ins({{.*}}tensor<?x?xf32>, tensor<?x?xf32>)
+  // CHECK-SAME:    outs({{.*}}tensor<?x?xf32>) -> tensor<?x?xf32>
+  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  //      CHECK:  tensor.cast
+  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<*xf32>
+
+  return %1: tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @linalg_effects(
 //  CHECK-SAME:     %[[A:[a-z0-9]*]]: tensor<?x?xf32>
 //  CHECK-SAME:     %[[B:[a-z0-9]*]]: memref<?x?xf32>
