[Mlir-commits] [mlir] 55270f5 - [mlir][sparse] fix a bug in unpack op that used wrong compare predicate.

Peiming Liu llvmlistbot at llvm.org
Wed Mar 8 11:52:15 PST 2023


Author: Peiming Liu
Date: 2023-03-08T19:52:09Z
New Revision: 55270f56d2a0992e9aa238fad5bdae03537d1032

URL: https://github.com/llvm/llvm-project/commit/55270f56d2a0992e9aa238fad5bdae03537d1032
DIFF: https://github.com/llvm/llvm-project/commit/55270f56d2a0992e9aa238fad5bdae03537d1032.diff

LOG: [mlir][sparse] fix a bug in unpack op that used wrong compare predicate.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D145603

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 1a6c8ea8654a0..80f299692af1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -578,7 +578,8 @@ static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
 
   Value targetLen = constantIndex(builder, loc, len);
   Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
-  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+  // Reallocates if target length is greater than the actual buffer len.
+  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                                  targetLen, bufferLen);
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
   // If targetLen > bufferLen, reallocate to get enough sparse to return.

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 494ced26d8aeb..99befbeb2f1a5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -43,7 +43,7 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
 // CHECK:           %[[VAL_4:.*]] = arith.constant 6 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
-// CHECK:           %[[VAL_7:.*]] = arith.cmpi ult, %[[VAL_4]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
 // CHECK:           %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
 // CHECK:             %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
 // CHECK:             scf.yield %[[VAL_9]] : memref<6xf64>
@@ -53,7 +53,7 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
 // CHECK:           }
 // CHECK:           %[[VAL_11:.*]] = arith.constant 12 : index
 // CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
-// CHECK:           %[[VAL_13:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_13:.*]] = arith.cmpi ugt, %[[VAL_11]], %[[VAL_12]] : index
 // CHECK:           %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
 // CHECK:             %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
 // CHECK:             scf.yield %[[VAL_15]] : memref<12xi32>


        


More information about the Mlir-commits mailing list