[Mlir-commits] [mlir] d22df0e - [mlir][sparse] refine insertion code
Aart Bik
llvmlistbot@llvm.org
Tue Oct 18 14:16:48 PDT 2022
Author: Aart Bik
Date: 2022-10-18T14:16:38-07:00
New Revision: d22df0ebba3f28700248c6de8a9d302628d0d2fb
URL: https://github.com/llvm/llvm-project/commit/d22df0ebba3f28700248c6de8a9d302628d0d2fb
DIFF: https://github.com/llvm/llvm-project/commit/d22df0ebba3f28700248c6de8a9d302628d0d2fb.diff
LOG: [mlir][sparse] refine insertion code
Builds an SSA cycle for the compress insertion loop.
Adds an index cast when the pushed value's type does not match the storage element type during push_back.
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D136186
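For illustration, a minimal sketch of the IR shape these two changes target. This is not verbatim output of the pass; the value names (%count, %added, %values, %indices0, %values0, %memSizes, %ptr, %index) are hypothetical stand-ins, chosen to mirror the CHECK patterns in the updated test below.

```mlir
// (1) SSA cycle: the compress loop threads the sparse-storage memrefs through
//     iter_args, so each sparse_tensor.push_back result feeds the next
//     iteration and the loop yields the final memrefs.
%r:2 = scf.for %i = %c0 to %count step %c1
    iter_args(%idx = %indices0, %val = %values0)
    -> (memref<?xindex>, memref<?xf64>) {
  %index = memref.load %added[%i] : memref<?xindex>
  %v     = memref.load %values[%index] : memref<?xf64>
  %idx2  = sparse_tensor.push_back %memSizes, %idx, %index {idx = 3 : index}
             : memref<3xindex>, memref<?xindex>, index
  %val2  = sparse_tensor.push_back %memSizes, %val, %v {idx = 4 : index}
             : memref<3xindex>, memref<?xf64>, f64
  scf.yield %idx2, %val2 : memref<?xindex>, memref<?xf64>
}

// (2) Cast on type mismatch: when the storage uses a narrower element type
//     (e.g. i32 pointers), createPushback now inserts an arith.index_cast
//     before pushing an index-typed value.
%p    = arith.index_cast %index : index to i32
%ptr2 = sparse_tensor.push_back %memSizes, %ptr, %p {idx = 2 : index}
          : memref<3xindex>, memref<?xi32>, i32
```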
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 5e5815b60061b..77bfef7c7d8aa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -265,11 +265,14 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
}
/// Creates a straightforward counting for-loop.
-static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
+static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
+ SmallVectorImpl<Value> &fields) {
Type indexType = builder.getIndexType();
Value zero = constantZero(builder, loc, indexType);
Value one = constantOne(builder, loc, indexType);
- scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one);
+ scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one, fields);
+ for (unsigned i = 0, e = fields.size(); i < e; i++)
+ fields[i] = forOp.getRegionIterArg(i);
builder.setInsertionPointToStart(forOp.getBody());
return forOp;
}
@@ -280,6 +283,9 @@ static void createPushback(OpBuilder &builder, Location loc,
SmallVectorImpl<Value> &fields, unsigned field,
Value value) {
assert(field < fields.size());
+ Type etp = fields[field].getType().cast<ShapedType>().getElementType();
+ if (value.getType() != etp)
+ value = builder.create<arith::IndexCastOp>(loc, etp, value);
fields[field] =
builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
fields[field], value, APInt(64, field));
@@ -298,11 +304,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) ||
!isOrderedDim(rtp, 0))
return; // TODO: add codegen
- // push_back memSizes pointers-0 0
// push_back memSizes indices-0 index
// push_back memSizes values value
- Value zero = constantIndex(builder, loc, 0);
- createPushback(builder, loc, fields, 2, zero);
createPushback(builder, loc, fields, 3, indices[0]);
createPushback(builder, loc, fields, 4, value);
}
@@ -316,9 +319,12 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
!isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
return; // TODO: add codegen
+ // push_back memSizes pointers-0 0
// push_back memSizes pointers-0 memSizes[2]
+ Value zero = constantIndex(builder, loc, 0);
Value two = constantIndex(builder, loc, 2);
Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
+ createPushback(builder, loc, fields, 2, zero);
createPushback(builder, loc, fields, 2, size);
}
@@ -460,6 +466,7 @@ class SparseTensorAllocConverter
Location loc = op.getLoc();
SmallVector<Value, 8> fields;
createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+ // Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
return success();
}
@@ -504,6 +511,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
// Generate optional insertion finalization code.
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), srcType, fields);
+ // Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
return success();
}
@@ -591,23 +599,26 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
// sparsity of the expanded access pattern.
//
// Generate
- // for (i = 0; i < count; i++) {
+ // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
// index = added[i];
// value = values[index];
// insert({prev_indices, index}, value);
+ // new_memrefs = insert(in_memrefs, {prev_indices, index}, value);
// values[index] = 0;
// filled[index] = false;
+ // yield new_memrefs
// }
- Value i = createFor(rewriter, loc, count).getInductionVar();
+ scf::ForOp loop = createFor(rewriter, loc, count, fields);
+ Value i = loop.getInductionVar();
Value index = rewriter.create<memref::LoadOp>(loc, added, i);
Value value = rewriter.create<memref::LoadOp>(loc, values, index);
indices.push_back(index);
- // TODO: generate yield cycle
genInsert(rewriter, loc, dstType, fields, indices, value);
rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
values, index);
rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
filled, index);
+ rewriter.create<scf::YieldOp>(loc, fields);
// Deallocate the buffers on exit of the full loop nest.
Operation *parent = op;
for (; isa<scf::ForOp>(parent->getParentOp()) ||
@@ -620,7 +631,9 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
rewriter.create<memref::DeallocOp>(loc, values);
rewriter.create<memref::DeallocOp>(loc, filled);
rewriter.create<memref::DeallocOp>(loc, added);
- rewriter.replaceOp(op, genTuple(rewriter, loc, dstType, fields));
+ // Replace operation with resulting memrefs.
+ rewriter.replaceOp(op,
+ genTuple(rewriter, loc, dstType, loop->getResults()));
return success();
}
};
@@ -641,6 +654,7 @@ class SparseInsertConverter : public OpConversionPattern<InsertOp> {
// Generate insertion.
Value value = adaptor.getValue();
genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+ // Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
return success();
}
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index ecdf0e45beb55..b469e669f8218 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -354,6 +354,49 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
return %added : memref<?xindex>
}
+// CHECK-LABEL: func @sparse_compression_1d(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[A8:.*8]]: index)
+// CHECK-DAG: %[[B0:.*]] = arith.constant false
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
+// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[T1:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+// CHECK: %[[T2:.*]] = memref.load %[[A5]][%[[T1]]] : memref<?xf64>
+// CHECK: %[[T3:.*]] = sparse_tensor.push_back %[[A1]], %[[P0]], %[[T1]] {idx = 3 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[T2]] {idx = 4 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: memref.store %[[F0]], %arg5[%[[T1]]] : memref<?xf64>
+// CHECK: memref.store %[[B0]], %arg6[%[[T1]]] : memref<?xi1>
+// CHECK: scf.yield %[[T3]], %[[T4]] : memref<?xindex>, memref<?xf64>
+// CHECK: }
+// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
+// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
+// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
+// CHECK: %[[LL:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]] {idx = 2 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: %[[P2:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[LL]] {idx = 2 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: return %[[A0]], %[[A1]], %[[P2]], %[[R]]#0, %[[R]]#1 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
+ %values: memref<?xf64>,
+ %filled: memref<?xi1>,
+ %added: memref<?xindex>,
+ %count: index) -> tensor<100xf64, #SV> {
+ %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[]
+ : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<100xf64, #SV>
+ %1 = sparse_tensor.load %0 hasInserts : tensor<100xf64, #SV>
+ return %1 : tensor<100xf64, #SV>
+}
+
// CHECK-LABEL: func @sparse_compression(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -372,7 +415,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-// TODO: insert
+// TODO: 2D-insert
// CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
// CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
// CHECK-NEXT: }
@@ -388,7 +431,8 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
%i: index) -> tensor<8x8xf64, #CSR> {
%0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
- return %0 : tensor<8x8xf64, #CSR>
+ %1 = sparse_tensor.load %0 hasInserts : tensor<8x8xf64, #CSR>
+ return %1 : tensor<8x8xf64, #CSR>
}
// CHECK-LABEL: func @sparse_compression_unordered(
@@ -409,7 +453,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
// CHECK-NOT: sparse_tensor.sort
// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-// TODO: insert
+// TODO: 2D-insert
// CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
// CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
// CHECK-NEXT: }
@@ -425,7 +469,8 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
%i: index) -> tensor<8x8xf64, #UCSR> {
%0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
- return %0 : tensor<8x8xf64, #UCSR>
+ %1 = sparse_tensor.load %0 hasInserts : tensor<8x8xf64, #UCSR>
+ return %1 : tensor<8x8xf64, #UCSR>
}
// CHECK-LABEL: func @sparse_insert(
@@ -438,10 +483,10 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
// CHECK-SAME: %[[A6:.*6]]: f64)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A5]]
// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
// CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[T3]]
// CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
@@ -449,3 +494,27 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
return %1 : tensor<128xf64, #SV>
}
+
+// CHECK-LABEL: func @sparse_insert_typed(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[S1:.*]] = arith.index_cast %[[A5]] : index to i64
+// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[S1]]
+// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
+// CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
+// CHECK: %[[S2:.*]] = arith.index_cast %[[T3]] : index to i32
+// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[S2]]
+// CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
+ %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+ %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
+ return %1 : tensor<128xf64, #SparseVector>
+}