[Mlir-commits] [mlir] 8a583bd - [mlir][sparse] Add codegen for expand op.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 8 14:06:09 PDT 2022
Author: bixia1
Date: 2022-09-08T14:06:01-07:00
New Revision: 8a583bd53dcb723b7a2b5e950e9d78da31d0e6cc
URL: https://github.com/llvm/llvm-project/commit/8a583bd53dcb723b7a2b5e950e9d78da31d0e6cc
DIFF: https://github.com/llvm/llvm-project/commit/8a583bd53dcb723b7a2b5e950e9d78da31d0e6cc.diff
LOG: [mlir][sparse] Add codegen for expand op.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133454
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index f7f4a39a95f2..8d2e06942497 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -169,6 +169,7 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
let dependentDialects = [
"arith::ArithmeticDialect",
"bufferization::BufferizationDialect",
+ "linalg::LinalgDialect",
"memref::MemRefDialect",
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 5eca3c86fe03..9ad37bf498f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -474,6 +475,58 @@ class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
}
};
+/// Sparse codegen rule for the expand op. Replaces the op with memref
+/// buffers for values, filled-switch, and added indices (each sized by the
+/// innermost stored dimension) plus an initial count of zero.
+class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+ Type eltType = srcType.getElementType();
+ Type boolType = rewriter.getIntegerType(1);
+ Type idxType = rewriter.getIndexType();
+ // All initialization should be done on entry of the loop nest. The tensor
+ // may be a block argument (no defining op), so position by value instead.
+ rewriter.setInsertionPointAfterValue(op.getTensor());
+ // Determine the size for access expansion (always the innermost stored
+ // dimension size, translated back to original dimension). Note that we
+ // recursively rewrite the new DimOp on the **original** tensor.
+ auto enc = getSparseTensorEncoding(srcType);
+ unsigned innerDim = srcType.getRank() - 1;
+ if (AffineMap p = enc.getDimOrdering())
+ innerDim = p.getDimPosition(innerDim);
+ Value sz = rewriter.create<tensor::DimOp>(loc, op.getTensor(), innerDim);
+ // Generate a memref for `sz` elements of type `t`.
+ auto genAlloc = [&](Type t) {
+ auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
+ return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
+ };
+ // Allocate temporary buffers for values, filled-switch, and indices.
+ // Heap (not stack) buffers, since the expanded size may be large.
+ Value values = genAlloc(eltType);
+ Value filled = genAlloc(boolType);
+ Value indices = genAlloc(idxType);
+ Value zero = constantZero(rewriter, loc, idxType);
+ // Reset values/filled-switch to all-zero/false. This O(N) reset is
+ // amortized over the innermost loops of the access pattern expansion;
+ // ideally (see the op doc) it would be amortized even between kernels.
+ rewriter.create<linalg::FillOp>(
+ loc, ValueRange{constantZero(rewriter, loc, eltType)},
+ ValueRange{values});
+ rewriter.create<linalg::FillOp>(
+ loc, ValueRange{constantZero(rewriter, loc, boolType)},
+ ValueRange{filled});
+ // Replace expansion op with these buffers and initial index.
+ assert(op.getNumResults() == 4);
+ rewriter.replaceOp(op, {values, filled, indices, zero});
+ return success();
+ }
+};
+
/// Sparse codegen rule for pointer accesses.
class SparseToPointersConverter
: public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
@@ -533,8 +586,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
- SparseCastConverter, SparseTensorAllocConverter,
- SparseTensorDeallocConverter, SparseToPointersConverter,
- SparseToIndicesConverter, SparseToValuesConverter,
- SparseTensorLoadConverter>(typeConverter, patterns.getContext());
+ SparseCastConverter, SparseExpandConverter,
+ SparseTensorAllocConverter, SparseTensorDeallocConverter,
+ SparseToPointersConverter, SparseToIndicesConverter,
+ SparseToValuesConverter, SparseTensorLoadConverter>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index fee4222cb53d..ebb6993a767f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -175,7 +175,9 @@ struct SparseTensorCodegenPass
[&](bufferization::DeallocTensorOp op) {
return converter.isLegal(op.getTensor().getType());
});
- // Legal dialects may occur in generated code.
+ // The following operations and dialects may be introduced by the
+ // codegen rules, and are therefore marked as legal.
+ target.addLegalOp<linalg::FillOp>();
target.addLegalDialect<arith::ArithmeticDialect,
bufferization::BufferizationDialect,
memref::MemRefDialect, scf::SCFDialect>();
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 667a5e1e4e23..a2bd75429d48 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -286,3 +286,53 @@ func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
%1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
return %1 : tensor<10x20x30xf64, #Dense3D>
}
+
+// CHECK-LABEL: func.func @sparse_expansion1()
+// CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
+// CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
+// CHECK: %[[C:.*]] = memref.alloc() : memref<8xindex>
+// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<8xindex> to memref<?xindex>
+// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<8xf64>)
+// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>)
+// CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion1() -> memref<?xindex> {
+ %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
+ %values, %filled, %added, %count = sparse_tensor.expand %0
+ : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ return %added : memref<?xindex>
+}
+
+// CHECK-LABEL: func.func @sparse_expansion2()
+// CHECK: %[[A:.*]] = memref.alloc() : memref<4xf64>
+// CHECK: %[[B:.*]] = memref.alloc() : memref<4xi1>
+// CHECK: %[[C:.*]] = memref.alloc() : memref<4xindex>
+// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<4xindex> to memref<?xindex>
+// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<4xf64>)
+// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<4xi1>)
+// CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion2() -> memref<?xindex> {
+ %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
+ %values, %filled, %added, %count = sparse_tensor.expand %0
+ : tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ return %added : memref<?xindex>
+}
+
+// CHECK-LABEL: func.func @sparse_expansion3(
+// CHECK-SAME: %[[D0:.*]]: index,
+// CHECK-SAME: %{{.*}}: index) -> memref<?xindex> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+// CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+// CHECK: %[[V:.*]] = memref.alloc(%[[D1]]) : memref<?xf64>
+// CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref<?xi1>
+// CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref<?xindex>
+// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref<?xf64>)
+// CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
+// CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
+ %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
+ %values, %filled, %added, %count = sparse_tensor.expand %0
+ : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ return %added : memref<?xindex>
+}
More information about the Mlir-commits mailing list