[Mlir-commits] [mlir] 928b5b0 - [mlir][sparse] add conversion rules for storage_get/set/callOp

Peiming Liu llvmlistbot at llvm.org
Fri Sep 2 11:28:04 PDT 2022


Author: Peiming Liu
Date: 2022-09-02T18:27:54Z
New Revision: 928b5b06f94dd95350bf9df298845bdecf60fa40

URL: https://github.com/llvm/llvm-project/commit/928b5b06f94dd95350bf9df298845bdecf60fa40
DIFF: https://github.com/llvm/llvm-project/commit/928b5b06f94dd95350bf9df298845bdecf60fa40.diff

LOG: [mlir][sparse] add conversion rules for storage_get/set/callOp

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133175

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
    mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 538041a1b36a6..9057921e87c68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -217,10 +217,12 @@ struct SparseTensorStorageExpansionPass
     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
       return converter.isLegal(op.getOperandTypes());
     });
+    // We generate UnrealizedConversionCastOp to intermix tuples and a
+    // list of types.
+    target.addLegalOp<UnrealizedConversionCastOp>();
     // Populate with rules and apply rewriting rules.
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                    converter);
-    populateCallOpTypeConversionPattern(patterns, converter);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
     populateSparseTensorStorageExpansionPatterns(converter, patterns);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
index c1305eba79009..31370e32cb063 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -41,10 +41,69 @@ convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
   return llvm::None;
 }
 
+/// Flatten a list of operands that may contain tuples.
+///
+/// An operand produced by an unrealized_conversion_cast with a single tuple
+/// result (inserted by the type converter to bridge the 1:N conversion between
+/// a tuple and its element types) is replaced by the cast's operands. Every
+/// other operand is forwarded unchanged. For example:
+///   tuple<a, b>, c, tuple<d, e>
+///   ==>
+///   a, b, c, d, e
+static void flattenOperands(ValueRange operands,
+                            SmallVectorImpl<Value> &flattened) {
+  for (Value operand : operands) {
+    // Block arguments have no defining op. Use the null-tolerant
+    // `getDefiningOp<OpTy>()` instead of `dyn_cast` on a possibly-null
+    // pointer, which asserts.
+    auto castOp = operand.getDefiningOp<UnrealizedConversionCastOp>();
+    if (castOp && castOp->getNumResults() == 1 &&
+        castOp->getResultTypes()[0].isa<TupleType>())
+      flattened.append(castOp.getOperands().begin(),
+                       castOp.getOperands().end());
+    else
+      flattened.push_back(operand);
+  }
+}
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
 
+/// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
+///
+/// After tuple expansion, the storage operand is expected to be an
+/// unrealized_conversion_cast whose operands are the flattened tuple elements.
+/// The storage_get is replaced by the operand at the requested index.
+class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Report a match failure (instead of crashing in `cast<>` on a null or
+    // foreign defining op) when the storage is not an expanded tuple.
+    auto castOp =
+        adaptor.getStorage().getDefiningOp<UnrealizedConversionCastOp>();
+    if (!castOp)
+      return rewriter.notifyMatchFailure(op,
+                                         "storage is not an expanded tuple");
+    uint64_t idx = op.getIdx().getZExtValue();
+    if (idx >= castOp.getOperands().size())
+      return rewriter.notifyMatchFailure(op, "index out of tuple bounds");
+
+    rewriter.replaceOp(op, castOp.getOperand(idx));
+    return success();
+  }
+};
+
+/// Sparse tensor storage conversion rule for sparse_tensor::storage_set.
+///
+/// Rebuilds the tuple-producing unrealized_conversion_cast, replacing the
+/// element at the requested index with the new value and forwarding all other
+/// elements unchanged.
+class SparseStorageSetConverter : public OpConversionPattern<StorageSetOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageSetOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Report a match failure (instead of crashing in `cast<>` on a null or
+    // foreign defining op) when the storage is not an expanded tuple.
+    auto castOp =
+        adaptor.getStorage().getDefiningOp<UnrealizedConversionCastOp>();
+    if (!castOp)
+      return rewriter.notifyMatchFailure(op,
+                                         "storage is not an expanded tuple");
+    uint64_t idx = op.getIdx().getZExtValue();
+
+    SmallVector<Value, 8> values(castOp.getOperands());
+    if (idx >= values.size())
+      return rewriter.notifyMatchFailure(op, "index out of tuple bounds");
+
+    // Update the corresponding element.
+    values[idx] = adaptor.getValue();
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{op.getType()}, values);
+    return success();
+  }
+};
+
 /// Sparse tensor storage conversion rule for returns.
 class SparseStorageReturnConverter
     : public OpConversionPattern<func::ReturnOp> {
@@ -54,24 +113,69 @@ class SparseStorageReturnConverter
   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value, 8> flattened;
-    for (auto operand : adaptor.getOperands()) {
-      if (auto cast =
-              dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
-          cast && cast->getResultTypes()[0].isa<TupleType>())
-        // An unrealized_conversion_cast will be inserted by type converter to
-        // inter-mix the gap between 1:N conversion between tuple and types.
-        // In this case, take the operands in the cast and replace the tuple
-        // output with the flattened type array.
-        flattened.append(cast.getOperands().begin(), cast.getOperands().end());
-      else
-        flattened.push_back(operand);
-    }
+    flattenOperands(adaptor.getOperands(), flattened); // Expand tuples.
     // Create a return with the flattened value extracted from tuple.
     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
     return success();
   }
 };
 
+/// Sparse tensor storage conversion rule for calls.
+///
+/// The default CallOp converter cannot handle 1:N type conversion properly.
+/// This rule therefore rebuilds the call with flattened operands and results,
+/// then re-packs each tuple result with an unrealized_conversion_cast:
+///    tuple(a, b), f, tuple(c, d) = call @foo(...)
+///  ==>
+///    a, b, f, c, d = call @foo(...)
+///    cast(a, b)->tuple, f, cast(c, d)->tuple
+class SparseStorageCallConverter : public OpConversionPattern<func::CallOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    SmallVector<Type, 8> finalRetTy;
+    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
+      return failure();
+
+    // (1) Generate a new call with flattened operands and return values.
+    SmallVector<Value, 8> flattened;
+    flattenOperands(adaptor.getOperands(), flattened);
+    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
+                                                 finalRetTy, flattened);
+
+    // (2) Map every result of the original call onto results of the new call,
+    // re-packing tuple results with a cast.
+    SmallVector<Value, 4> castedRet;
+    // Offset of the current original result within the flattened results of
+    // the new call.
+    unsigned retOffset = 0;
+    for (Value ret : op.getResults()) {
+      if (auto tupleRet = ret.getType().dyn_cast<TupleType>()) {
+        // NOTE: assumes a non-recursive tuple type, i.e. the converter expands
+        // it into exactly `tupleRet.size()` values.
+        unsigned tupleSize = tupleRet.size();
+        assert(retOffset + tupleSize <= newCall.getNumResults());
+        ValueRange tupleElem =
+            newCall->getResults().slice(retOffset, tupleSize);
+        auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+            loc, TypeRange({tupleRet}), tupleElem);
+        castedRet.push_back(castOp.getResult(0));
+        retOffset += tupleSize;
+      } else {
+        // A non-tuple result maps 1:1 onto the matching result of the *new*
+        // call. Reusing `ret` would replace the old call with its own result,
+        // which becomes dangling once the old call is erased.
+        assert(retOffset < newCall.getNumResults());
+        castedRet.push_back(newCall->getResult(retOffset));
+        retOffset++;
+      }
+    }
+
+    assert(castedRet.size() == op.getNumResults());
+    assert(retOffset == newCall.getNumResults());
+    rewriter.replaceOp(op, castedRet);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -91,6 +195,7 @@ mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
 /// to expand compounded sparse tensor tuples.
 void mlir::populateSparseTensorStorageExpansionPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<SparseStorageReturnConverter>(typeConverter,
-                                             patterns.getContext());
+  // Tuple-expansion rules for storage_get/storage_set, func.return and
+  // func.call; every rule shares the same type converter and context.
+  auto *context = patterns.getContext();
+  patterns.add<SparseStorageGetConverter, SparseStorageSetConverter,
+               SparseStorageReturnConverter, SparseStorageCallConverter>(
+      typeConverter, context);
 }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
index 445b234a2a8d2..87391978b0674 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-storage-expansion | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-storage-expansion -cse | FileCheck %s
 
 // CHECK-LABEL:  func @sparse_storage_expand(
 // CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
@@ -9,3 +9,41 @@ func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>
                                      -> tuple<memref<?xf64>, memref<?xf64>, f64> {
   return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
 }
+
+// A call whose argument and result are tuples is rewritten into a call that
+// passes and returns the flattened tuple elements directly.
+// CHECK-LABEL:  func @call_sparse_storage_expand(
+// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg2:.*]]: f64) 
+// CHECK:          %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]) 
+// CHECK:          return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
+func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
+                                          -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %1 = call @sparse_storage_expand(%arg0) : (tuple<memref<?xf64>, memref<?xf64>, f64>) ->
+                                             tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// storage_get on an expanded tuple forwards the selected element (here the
+// first function argument) without materializing the tuple.
+// CHECK-LABEL:  func @sparse_storage_get(
+// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg2:.*]]: f64) 
+// CHECK:          return %[[TMP_arg0]] : memref<?xf64>
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  %0 = sparse_tensor.storage_get %arg0[0]
+       : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// storage_set on an expanded tuple swaps in the new value at the selected
+// index; the returned element list keeps the other elements unchanged.
+// CHECK-LABEL:  func @sparse_storage_set(
+// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg2:.*]]: f64,
+// CHECK-SAME:     %[[TMP_arg3:.*]]: memref<?xf64>) 
+// CHECK:          return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>,
+                              %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %0 = sparse_tensor.storage_set %arg0[0], %arg1
+       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+         tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}


        


More information about the Mlir-commits mailing list