[Mlir-commits] [mlir] 8a583bd - [mlir][sparse] Add codegen for expand op.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 8 14:06:09 PDT 2022


Author: bixia1
Date: 2022-09-08T14:06:01-07:00
New Revision: 8a583bd53dcb723b7a2b5e950e9d78da31d0e6cc

URL: https://github.com/llvm/llvm-project/commit/8a583bd53dcb723b7a2b5e950e9d78da31d0e6cc
DIFF: https://github.com/llvm/llvm-project/commit/8a583bd53dcb723b7a2b5e950e9d78da31d0e6cc.diff

LOG: [mlir][sparse] Add codegen for expand op.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133454

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index f7f4a39a95f2..8d2e06942497 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -169,6 +169,7 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
   let dependentDialects = [
     "arith::ArithmeticDialect",
     "bufferization::BufferizationDialect",
+    "linalg::LinalgDialect",
     "memref::MemRefDialect",
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 5eca3c86fe03..9ad37bf498f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -19,6 +19,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -474,6 +475,58 @@ class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
   }
 };
 
+/// Sparse codegen rule for the expand op (allocates and initializes results).
+class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+    Type eltType = srcType.getElementType();
+    Type boolType = rewriter.getIntegerType(1);
+    Type idxType = rewriter.getIndexType();
+    // Initialize on entry of the loop nest, right after the tensor definition.
+    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
+    // Determine the size for access expansion (always the innermost stored
+    // dimension size, translated back to original dimension). Note that we
+    // recursively rewrite the new DimOp on the **original** tensor.
+    auto enc = getSparseTensorEncoding(srcType);
+    unsigned innerDim = srcType.getRank() - 1;
+    if (AffineMap p = enc.getDimOrdering())
+      innerDim = p.getDimPosition(innerDim);
+    Value sz = rewriter.create<tensor::DimOp>(loc, op.getTensor(), innerDim);
+    // Allocate a dynamically sized memref holding `sz` elements of type `t`.
+    auto genAlloc = [&](Type t) {
+      auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
+      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
+    };
+    // Allocate temporary buffers for values, filled-switch, and indices.
+    // We do not use stack buffers for this, since the expanded size may
+    // be rather large (as it envelops a single expanded dense dimension).
+    Value values = genAlloc(eltType);
+    Value filled = genAlloc(boolType);
+    Value indices = genAlloc(idxType);
+    Value zero = constantZero(rewriter, loc, idxType);
+    // Reset the values/filled-switch to all-zero/false. Note that this
+    // introduces an O(N) operation into the computation, but this reset
+    // operation is amortized over the innermost loops for the access
+    // pattern expansion. As noted in the operation doc, we would like
+    // to amortize this setup cost even between kernels.
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, eltType)},
+        ValueRange{values});
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, boolType)},
+        ValueRange{filled});
+    // Replace the expand op with the three buffers and a zero initial count.
+    assert(op.getNumResults() == 4);
+    rewriter.replaceOp(op, {values, filled, indices, zero});
+    return success();
+  }
+};
+
 /// Sparse codegen rule for pointer accesses.
 class SparseToPointersConverter
     : public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
@@ -533,8 +586,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-               SparseCastConverter, SparseTensorAllocConverter,
-               SparseTensorDeallocConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter,
-               SparseTensorLoadConverter>(typeConverter, patterns.getContext());
+               SparseCastConverter, SparseExpandConverter,
+               SparseTensorAllocConverter, SparseTensorDeallocConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter, SparseTensorLoadConverter>(
+               typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index fee4222cb53d..ebb6993a767f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -175,7 +175,9 @@ struct SparseTensorCodegenPass
         [&](bufferization::DeallocTensorOp op) {
           return converter.isLegal(op.getTensor().getType());
         });
-    // Legal dialects may occur in generated code.
+    // The following operations and dialects may be introduced by the
+    // codegen rules, and are therefore marked as legal.
+    target.addLegalOp<linalg::FillOp>();
     target.addLegalDialect<arith::ArithmeticDialect,
                            bufferization::BufferizationDialect,
                            memref::MemRefDialect, scf::SCFDialect>();

diff  --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 667a5e1e4e23..a2bd75429d48 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -286,3 +286,53 @@ func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
   %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
   return %1 : tensor<10x20x30xf64, #Dense3D>
 }
+
+//   CHECK-LABEL: func.func @sparse_expansion1()
+//         CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
+//         CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
+//         CHECK: %[[C:.*]] = memref.alloc() : memref<8xindex>
+//         CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<8xindex> to memref<?xindex>
+//     CHECK-DAG: linalg.fill ins(%{{.*}}  : f64) outs(%[[A]] : memref<8xf64>)
+//     CHECK-DAG: linalg.fill ins(%{{.*}}  : i1) outs(%[[B]] : memref<8xi1>)
+//         CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion1() -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}
+
+//   CHECK-LABEL: func.func @sparse_expansion2()
+//         CHECK: %[[A:.*]] = memref.alloc() : memref<4xf64>
+//         CHECK: %[[B:.*]] = memref.alloc() : memref<4xi1>
+//         CHECK: %[[C:.*]] = memref.alloc() : memref<4xindex>
+//         CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<4xindex> to memref<?xindex>
+//     CHECK-DAG: linalg.fill ins(%{{.*}}  : f64) outs(%[[A]] : memref<4xf64>)
+//     CHECK-DAG: linalg.fill ins(%{{.*}}  : i1) outs(%[[B]] : memref<4xi1>)
+//         CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion2() -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}
+
+//   CHECK-LABEL: func.func @sparse_expansion3(
+//    CHECK-SAME: %[[D0:.*]]: index,
+//    CHECK-SAME: %{{.*}}: index) -> memref<?xindex> {
+//         CHECK: %[[C1:.*]] = arith.constant 1 : index
+//         CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex>
+//         CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+//         CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+//         CHECK: %[[V:.*]] = memref.alloc(%[[D1]]) : memref<?xf64>
+//         CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref<?xi1>
+//         CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref<?xindex>
+//         CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref<?xf64>)
+//         CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
+//         CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}


        


More information about the Mlir-commits mailing list