[Mlir-commits] [mlir] 08f0cb7 - [mlir] Prevent crash in DropUnitDim pattern due to tensor with encoding
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 17 12:03:26 PDT 2021
Author: thomasraoux
Date: 2021-09-17T12:03:16-07:00
New Revision: 08f0cb77197dc2842baa00f22f0264fa49d1475a
URL: https://github.com/llvm/llvm-project/commit/08f0cb77197dc2842baa00f22f0264fa49d1475a
DIFF: https://github.com/llvm/llvm-project/commit/08f0cb77197dc2842baa00f22f0264fa49d1475a.diff
LOG: [mlir] Prevent crash in DropUnitDim pattern due to tensor with encoding
Differential Revision: https://reviews.llvm.org/D109984
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 8315de4c72e7..98a06bfda97a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -361,6 +361,12 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
+ // Skip the pattern if the op has any tensor with special encoding.
+ if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
+ auto tensorType = type.dyn_cast<RankedTensorType>();
+ return tensorType && tensorType.getEncoding() != nullptr;
+ }))
+ return failure();
MLIRContext *context = rewriter.getContext();
Location loc = genericOp.getLoc();
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 60ad72300a18..53bdf0aa712a 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -796,3 +796,34 @@ func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: me
// CHECK: linalg.yield %[[ARG]] : f32
// CHECK: }
// CHECK: return %[[ARG2]] : memref<?x1x?x1x?xf32>
+
+// -----
+
+// Negative test for case with tensor encoding.
+#matvec = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A
+ affine_map<(i,j) -> (j)>, // b
+ affine_map<(i,j) -> (i)> // x (out)
+ ],
+ iterator_types = ["parallel", "reduction"]
+}
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }>
+
+func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+ %0 = linalg.init_tensor [8] : tensor<8xf32>
+ %1 = linalg.generic #matvec
+ ins(%arg0, %arg1: tensor<8x8xf32, #CSR>, tensor<8xf32>)
+ outs(%0: tensor<8xf32>) {
+ ^bb(%a: f32, %b: f32, %x: f32):
+ %m = mulf %a, %b : f32
+ %add = addf %x, %m : f32
+ linalg.yield %add : f32
+ } -> tensor<8xf32>
+ return %1: tensor<8xf32>
+}
+
+// CHECK-LABEL: func @sparse_case
+// CHECK-NEXT: linalg.init_tensor
+// CHECK-NEXT: linalg.generic
More information about the Mlir-commits mailing list