[Mlir-commits] [mlir] 08f0cb7 - [mlir] Prevent crash in DropUnitDim pattern due to tensor with encoding

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 17 12:03:26 PDT 2021


Author: thomasraoux
Date: 2021-09-17T12:03:16-07:00
New Revision: 08f0cb77197dc2842baa00f22f0264fa49d1475a

URL: https://github.com/llvm/llvm-project/commit/08f0cb77197dc2842baa00f22f0264fa49d1475a
DIFF: https://github.com/llvm/llvm-project/commit/08f0cb77197dc2842baa00f22f0264fa49d1475a.diff

LOG: [mlir] Prevent crash in DropUnitDim pattern due to tensor with encoding

Differential Revision: https://reviews.llvm.org/D109984

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 8315de4c72e7..98a06bfda97a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -361,6 +361,12 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
+    // Skip the pattern if the op has any tensor with special encoding.
+    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
+          auto tensorType = type.dyn_cast<RankedTensorType>();
+          return tensorType && tensorType.getEncoding() != nullptr;
+        }))
+      return failure();
     MLIRContext *context = rewriter.getContext();
     Location loc = genericOp.getLoc();
 

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 60ad72300a18..53bdf0aa712a 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -796,3 +796,34 @@ func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: me
 // CHECK:       linalg.yield %[[ARG]] : f32
 // CHECK:      }
 // CHECK:      return %[[ARG2]] : memref<?x1x?x1x?xf32>
+
+// -----
+
+// Negative test: a linalg.generic with an operand tensor that has an encoding (#CSR below) must be left untouched by ReplaceUnitExtents; the CHECK-NEXT lines verify nothing is inserted between the init_tensor and the generic op.
+#matvec = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>, // A
+    affine_map<(i,j) -> (j)>,   // b
+    affine_map<(i,j) -> (i)>    // x (out)
+  ],
+  iterator_types = ["parallel", "reduction"]
+}
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }>
+
+func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+    %0 = linalg.init_tensor [8] : tensor<8xf32>
+    %1 = linalg.generic #matvec
+      ins(%arg0, %arg1: tensor<8x8xf32, #CSR>, tensor<8xf32>)
+      outs(%0: tensor<8xf32>) {
+      ^bb(%a: f32, %b: f32, %x: f32):
+        %m = mulf %a, %b : f32
+        %add = addf %x, %m : f32
+        linalg.yield %add : f32
+    } -> tensor<8xf32>
+    return %1: tensor<8xf32>
+}
+
+// CHECK-LABEL: func @sparse_case
+//  CHECK-NEXT:   linalg.init_tensor
+//  CHECK-NEXT:   linalg.generic


        


More information about the Mlir-commits mailing list