[Mlir-commits] [mlir] 55270f5 - [mlir][sparse] fix a bug in unpack op that used wrong compare predicate.
Peiming Liu
llvmlistbot at llvm.org
Wed Mar 8 11:52:15 PST 2023
Author: Peiming Liu
Date: 2023-03-08T19:52:09Z
New Revision: 55270f56d2a0992e9aa238fad5bdae03537d1032
URL: https://github.com/llvm/llvm-project/commit/55270f56d2a0992e9aa238fad5bdae03537d1032
DIFF: https://github.com/llvm/llvm-project/commit/55270f56d2a0992e9aa238fad5bdae03537d1032.diff
LOG: [mlir][sparse] fix a bug in unpack op that used wrong compare predicate.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D145603
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 1a6c8ea8654a0..80f299692af1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -578,7 +578,8 @@ static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
Value targetLen = constantIndex(builder, loc, len);
Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
- Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+ // Reallocates if target length is greater than the actual buffer len.
+ Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
targetLen, bufferLen);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
// If targetLen > bufferLen, reallocate to get enough space to return.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 494ced26d8aeb..99befbeb2f1a5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -43,7 +43,7 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
// CHECK: %[[VAL_4:.*]] = arith.constant 6 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
-// CHECK: %[[VAL_7:.*]] = arith.cmpi ult, %[[VAL_4]], %[[VAL_6]] : index
+// CHECK: %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
// CHECK: %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
// CHECK: scf.yield %[[VAL_9]] : memref<6xf64>
@@ -53,7 +53,7 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
// CHECK: }
// CHECK: %[[VAL_11:.*]] = arith.constant 12 : index
// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
-// CHECK: %[[VAL_13:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK: %[[VAL_13:.*]] = arith.cmpi ugt, %[[VAL_11]], %[[VAL_12]] : index
// CHECK: %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
// CHECK: %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
// CHECK: scf.yield %[[VAL_15]] : memref<12xi32>
More information about the Mlir-commits mailing list