[Mlir-commits] [mlir] 928b5b0 - [mlir][sparse] add conversion rules for storage_get/set/callOp
Peiming Liu
llvmlistbot at llvm.org
Fri Sep 2 11:28:04 PDT 2022
Author: Peiming Liu
Date: 2022-09-02T18:27:54Z
New Revision: 928b5b06f94dd95350bf9df298845bdecf60fa40
URL: https://github.com/llvm/llvm-project/commit/928b5b06f94dd95350bf9df298845bdecf60fa40
DIFF: https://github.com/llvm/llvm-project/commit/928b5b06f94dd95350bf9df298845bdecf60fa40.diff
LOG: [mlir][sparse] add conversion rules for storage_get/set/callOp
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133175
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 538041a1b36a6..9057921e87c68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -217,10 +217,12 @@ struct SparseTensorStorageExpansionPass
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
return converter.isLegal(op.getOperandTypes());
});
+ // We generate UnrealizedConversionCastOp to intermix tuples and a
+ // list of types.
+ target.addLegalOp<UnrealizedConversionCastOp>();
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
- populateCallOpTypeConversionPattern(patterns, converter);
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateSparseTensorStorageExpansionPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
index c1305eba79009..31370e32cb063 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -41,10 +41,69 @@ convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
return llvm::None;
}
+/// Flattens a list of operands that may contain tuples, appending the result
+/// to `flattened`.
+///
+/// An operand is treated as a tuple when it is defined by an
+/// unrealized_conversion_cast that has exactly one result of TupleType. For
+/// such an operand the inputs of the cast are appended in place of the tuple
+/// value. Every other operand (including block arguments, which have no
+/// defining operation) is appended unchanged.
+static void flattenOperands(ValueRange operands,
+                            SmallVectorImpl<Value> &flattened) {
+  // In case of
+  //   tuple<a, b>, c, tuple<d, e>
+  // ==>
+  //   a, b, c, d, e
+  for (Value operand : operands) {
+    // `getDefiningOp<OpTy>()` returns a null op for block arguments and for
+    // values defined by any other operation, whereas a plain `dyn_cast` on
+    // the raw (possibly null) defining op is not allowed on a null pointer.
+    auto cast = operand.getDefiningOp<UnrealizedConversionCastOp>();
+    if (cast && cast->getNumResults() == 1 &&
+        cast->getResultTypes()[0].isa<TupleType>()) {
+      // An unrealized_conversion_cast will be inserted by type converter to
+      // bridge the gap of the 1:N conversion between a tuple and its element
+      // types. In this case, take the operands of the cast and replace the
+      // tuple output with the flattened value array.
+      flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+    } else {
+      flattened.push_back(operand);
+    }
+  }
+}
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
+/// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
+///
+/// The converted storage operand is expected to be defined by an
+/// unrealized_conversion_cast whose inputs are the individual storage fields.
+/// The op is replaced by the input of that cast selected by the op's index.
+/// The pattern reports a match failure (instead of asserting) when the storage
+/// is not defined by such a cast or when the index is out of range.
+class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Null when the storage is a block argument or is defined by another op.
+    auto castOp =
+        adaptor.getStorage().getDefiningOp<UnrealizedConversionCastOp>();
+    if (!castOp)
+      return rewriter.notifyMatchFailure(
+          op, "storage is not defined by an unrealized_conversion_cast");
+
+    uint64_t idx = op.getIdx().getZExtValue();
+    if (idx >= castOp.getOperands().size())
+      return rewriter.notifyMatchFailure(op, "storage index out of bounds");
+
+    rewriter.replaceOp(op, castOp.getOperand(idx));
+    return success();
+  }
+};
+
+/// Sparse tensor storage conversion rule for sparse_tensor::storage_set.
+///
+/// The converted storage operand is expected to be defined by an
+/// unrealized_conversion_cast whose inputs are the individual storage fields.
+/// The op is replaced by a new unrealized_conversion_cast over the same
+/// inputs, except that the input selected by the op's index is replaced with
+/// the converted value operand. The pattern reports a match failure (instead
+/// of asserting) when the storage is not defined by such a cast or when the
+/// index is out of range.
+class SparseStorageSetConverter : public OpConversionPattern<StorageSetOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageSetOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Null when the storage is a block argument or is defined by another op.
+    auto castOp =
+        adaptor.getStorage().getDefiningOp<UnrealizedConversionCastOp>();
+    if (!castOp)
+      return rewriter.notifyMatchFailure(
+          op, "storage is not defined by an unrealized_conversion_cast");
+
+    uint64_t idx = op.getIdx().getZExtValue();
+    SmallVector<Value, 8> values(castOp.getOperands());
+    if (idx >= values.size())
+      return rewriter.notifyMatchFailure(op, "storage index out of bounds");
+
+    // Updates the corresponding element.
+    values[idx] = adaptor.getValue();
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{op.getType()}, values);
+    return success();
+  }
+};
+
/// Sparse tensor storage conversion rule for returns.
class SparseStorageReturnConverter
: public OpConversionPattern<func::ReturnOp> {
@@ -54,24 +113,69 @@ class SparseStorageReturnConverter
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value, 8> flattened;
- for (auto operand : adaptor.getOperands()) {
- if (auto cast =
- dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
- cast && cast->getResultTypes()[0].isa<TupleType>())
- // An unrealized_conversion_cast will be inserted by type converter to
- // inter-mix the gap between 1:N conversion between tuple and types.
- // In this case, take the operands in the cast and replace the tuple
- // output with the flattened type array.
- flattened.append(cast.getOperands().begin(), cast.getOperands().end());
- else
- flattened.push_back(operand);
- }
+ flattenOperands(adaptor.getOperands(), flattened);
// Create a return with the flattened value extracted from tuple.
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
return success();
}
};
+/// Sparse tensor storage conversion rule for calls.
+///
+/// Rewrites a call so that tuple operands and tuple results are passed as
+/// flattened value lists. Each tuple result of the original call is rebuilt
+/// from the matching results of the new call with an
+/// unrealized_conversion_cast; every non-tuple result is forwarded from the
+/// matching result of the new call.
+class SparseStorageCallConverter : public OpConversionPattern<func::CallOp> {
+public:
+  // The default CallOp converter can not handle 1:N type conversion properly
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // In case of:
+    //  tuple(a, b), f, tuple(c, d) = call @foo(...)
+    // ==>
+    //  a, b, f, c, d = call @foo(...)
+    //  cast(a, b)->tuple, f, cast(c,d)->tuple
+    SmallVector<Type, 8> finalRetTy;
+    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
+      return failure();
+
+    // (1) Generates new call with flattened operands and return values.
+    SmallVector<Value, 8> flattened;
+    flattenOperands(adaptor.getOperands(), flattened);
+    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
+                                                 finalRetTy, flattened);
+
+    // (2) Create cast operation for tuple returns.
+    SmallVector<Value, 4> castedRet;
+    castedRet.reserve(op.getNumResults());
+    // Tracks the offset of current return value (of the original call)
+    // relative to the new call (after tuple flattening);
+    unsigned retOffset = 0;
+    for (auto ret : op.getResults()) {
+      if (auto tupleRet = ret.getType().dyn_cast<TupleType>()) {
+        auto tupleSize = tupleRet.size();
+        // NOTE: The range is computed under the assumption of non-recursive
+        // tuple type.
+        assert(retOffset + tupleSize <= newCall.getNumResults());
+        ValueRange tupleElem(iterator_range<ResultRange::iterator>(
+            newCall.result_begin() + retOffset,
+            newCall.result_begin() + retOffset + tupleSize));
+        auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+            loc, TypeRange({tupleRet}), tupleElem);
+        castedRet.push_back(castOp.getResult(0));
+        retOffset += tupleSize;
+      } else {
+        // If this is not a tuple, forward the matching result of the new
+        // call. The results of the original call must not be used here: the
+        // original call is the op being replaced.
+        assert(retOffset < newCall.getNumResults());
+        castedRet.push_back(newCall.getResult(retOffset));
+        retOffset++;
+      }
+    }
+
+    assert(castedRet.size() == op.getNumResults());
+    rewriter.replaceOp(op, castedRet);
+    return success();
+  }
+};
+
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -91,6 +195,7 @@ mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
/// to expand compounded sparse tensor tuples.
void mlir::populateSparseTensorStorageExpansionPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
- patterns.add<SparseStorageReturnConverter>(typeConverter,
- patterns.getContext());
+ patterns.add<SparseStorageGetConverter, SparseStorageSetConverter,
+ SparseStorageReturnConverter, SparseStorageCallConverter>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
index 445b234a2a8d2..87391978b0674 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-storage-expansion | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-storage-expansion -cse | FileCheck %s
// CHECK-LABEL: func @sparse_storage_expand(
// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
@@ -9,3 +9,41 @@ func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>
-> tuple<memref<?xf64>, memref<?xf64>, f64> {
return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
}
+
+// CHECK-LABEL: func @call_sparse_storage_expand(
+// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
+// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
+// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
+func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
+ -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+ %1 = call @sparse_storage_expand(%arg0) : (tuple<memref<?xf64>, memref<?xf64>, f64>) ->
+ tuple<memref<?xf64>, memref<?xf64>, f64>
+ return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// CHECK-LABEL: func @sparse_storage_get(
+// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
+// CHECK: return %[[TMP_arg0]] : memref<?xf64>
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+ %0 = sparse_tensor.storage_get %arg0[0]
+ : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
+ return %0 : memref<?xf64>
+}
+
+// CHECK-LABEL: func @sparse_storage_set(
+// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64,
+// CHECK-SAME: %[[TMP_arg3:.*]]: memref<?xf64>)
+// CHECK: return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>,
+ %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+ %0 = sparse_tensor.storage_set %arg0[0], %arg1
+ : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+ tuple<memref<?xf64>, memref<?xf64>, f64>
+ return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
More information about the Mlir-commits
mailing list