[Mlir-commits] [mlir] d3f5f33 - [mlir][scf] support 1:N type conversion for scf.for.
Peiming Liu
llvmlistbot at llvm.org
Fri Oct 21 14:12:01 PDT 2022
Author: Peiming Liu
Date: 2022-10-21T21:11:55Z
New Revision: d3f5f330671e718a0e28598c412d09e9a3b54273
URL: https://github.com/llvm/llvm-project/commit/d3f5f330671e718a0e28598c412d09e9a3b54273
DIFF: https://github.com/llvm/llvm-project/commit/d3f5f330671e718a0e28598c412d09e9a3b54273.diff
LOG: [mlir][scf] support 1:N type conversion for scf.for.
scf.for used to support only 1:1 type conversion; this patch adds support for 1:N type conversion.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D136314
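For context: when a type converter maps one type to several (here, the sparse-tensor codegen converter lowers a single sparse tensor to a handful of memrefs), the conversion framework packs the N converted values back into one value of the original type with a builtin.unrealized_conversion_cast; the unpackUnrealizedConversionCast helper below undoes exactly that packing. A minimal sketch of such a packed value, with illustrative value names and the memref types borrowed from the test added in this patch:

    // Five converted memrefs packed back into a single sparse-tensor value,
    // so that not-yet-converted ops still type-check.
    %packed = builtin.unrealized_conversion_cast %dim, %mem, %ptr, %idx, %val
        : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>,
          memref<?xf32> to tensor<1024xf32, #SparseVector>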
Added:
mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
Modified:
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index fb15fcb3579e2..c4c219617b782 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -15,58 +15,102 @@ using namespace mlir;
using namespace mlir::scf;
namespace {
+
+// Unpacks a single unrealized_conversion_cast into its list of inputs,
+// e.g., returns [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d).
+static void unpackUnrealizedConversionCast(Value v,
+ SmallVectorImpl<Value> &unpacked) {
+ if (auto cast =
+ dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
+ if (cast.getInputs().size() != 1) {
+ // 1 : N type conversion.
+ unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
+ return;
+ }
+ }
+ // 1 : 1 type conversion.
+ unpacked.push_back(v);
+}
+
class ConvertForOpTypes : public OpConversionPattern<ForOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ForOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Type, 6> newResultTypes;
- for (auto type : op.getResultTypes()) {
- Type newType = typeConverter->convertType(type);
- if (!newType)
- return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
- newResultTypes.push_back(newType);
+ SmallVector<Type> newResultTypes;
+ SmallVector<unsigned> offsets;
+ offsets.push_back(0);
+ // Do the type conversion and record the offsets.
+ for (Type type : op.getResultTypes()) {
+ if (failed(typeConverter->convertTypes(type, newResultTypes)))
+ return rewriter.notifyMatchFailure(op, "could not convert result");
+ offsets.push_back(newResultTypes.size());
}
- // Clone the op without the regions and inline the regions from the old op.
+ // Create an empty new op and inline the regions from the old op.
//
// This is a little bit tricky. We have two concerns here:
//
// 1. We cannot update the op in place because the dialect conversion
// framework does not track type changes for ops updated in place, so it
// won't insert appropriate materializations on the changed result types.
- // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
- // clone the op.
+ // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
+ // to clone the op.
//
- // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being
- // inefficient to recursively clone the regions, there is a correctness
- // issue: if we clone with the regions, then the dialect conversion
- // framework thinks that we just inserted all the cloned child ops. But what
- // we want is to "take" the child regions and let the dialect conversion
- // framework continue recursively into ops inside those regions (which are
- // already in its worklist; inlining them into the new op's regions doesn't
- // remove the child ops from the worklist).
- ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
- // Take the region from the old op and put it in the new op.
+ // 2. We need to reuse the original region instead of cloning it; otherwise
+ // the dialect conversion framework thinks that we just inserted all the
+ // cloned child ops. But what we want is to "take" the child regions and let
+ // the dialect conversion framework continue recursively into ops inside
+ // those regions (which are already in its worklist; inlining them into the
+ // new op's regions doesn't remove the child ops from the worklist).
+
+ // convertRegionTypes already takes care of 1:N conversion.
+ if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
+ return failure();
+
+ // Unpack the iteration arguments.
+ SmallVector<Value> flatArgs;
+ for (Value arg : adaptor.getInitArgs())
+ unpackUnrealizedConversionCast(arg, flatArgs);
+
+ // We cannot simply clone the op, as the number of result types after
+ // conversion might be different.
+ ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
+ adaptor.getUpperBound(),
+ adaptor.getStep(), flatArgs);
+
+ // Preserve the attributes of the original op.
+ newOp->setAttrs(op->getAttrs());
+
+ // We do not need the empty block created by the rewriter.
+ rewriter.eraseBlock(newOp.getBody(0));
+ // Inline the type converted region from the original operation.
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
newOp.getLoopBody().end());
- // Now, update all the types.
-
- // Convert the type of the entry block of the ForOp's body.
- if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
- *getTypeConverter()))) {
- return rewriter.notifyMatchFailure(op, "could not convert body types");
+ // Pack the return values.
+ SmallVector<Value, 6> packedRets;
+ for (unsigned i = 1, e = offsets.size(); i < e; i++) {
+ unsigned start = offsets[i - 1], end = offsets[i];
+ unsigned len = end - start;
+ ValueRange mappedValue = newOp.getResults().slice(start, len);
+ if (len != 1) {
+ // 1 : N type conversion.
+ Type origType = op.getResultTypes()[i - 1];
+ Value mat = typeConverter->materializeSourceConversion(
+ rewriter, op.getLoc(), origType, mappedValue);
+ if (!mat)
+ return rewriter.notifyMatchFailure(
+ op, "Failed to materialize 1:N type conversion");
+ packedRets.push_back(mat);
+ } else {
+ // 1 : 1 type conversion.
+ packedRets.push_back(mappedValue.front());
+ }
}
- // Change the clone to use the updated operands. We could have cloned with
- // a BlockAndValueMapping, but this seems a bit more direct.
- newOp->setOperands(adaptor.getOperands());
- // Update the result types to the new converted types.
- for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
- std::get<0>(t).setType(std::get<1>(t));
- rewriter.replaceOp(op, newOp.getResults());
+ rewriter.replaceOp(op, packedRets);
return success();
}
};
@@ -81,12 +125,12 @@ class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
ConversionPatternRewriter &rewriter) const override {
// TODO: Generalize this to any type conversion, not just 1:1.
//
- // We need to implement something more sophisticated here that tracks which
- // types convert to which other types and does the appropriate
+ // We need to implement something more sophisticated here that tracks
+ // which types convert to which other types and does the appropriate
// materialization logic.
// For example, it's possible that one result type converts to 0 types and
- // another to 2 types, so newResultTypes would at least be the right size to
- // not crash in the llvm::zip call below, but then we would set the the
+ // another to 2 types, so newResultTypes would at least be the right size
+ // to not crash in the llvm::zip call below, but then we would set the
// wrong type on the SSA values! These edge cases are also why we cannot
// safely use the TypeConverter::convertTypes helper here.
SmallVector<Type, 6> newResultTypes;
@@ -125,7 +169,11 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
LogicalResult
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
+ SmallVector<Value> unpackedYield;
+ for (Value operand : adaptor.getOperands())
+ unpackUnrealizedConversionCast(operand, unpackedYield);
+
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
return success();
}
};
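Putting the pieces together, here is a rough before/after sketch of what ConvertForOpTypes now produces for a loop carrying one sparse vector; the value names are illustrative and the converted types follow the test below:

    // Before: a single sparse-tensor iter_arg and result.
    %r = scf.for %i = %lb to %ub step %s iter_args(%v = %t)
        -> tensor<1024xf32, #SparseVector> { ... }

    // After: five memref iter_args and results. Remaining uses of the old
    // result %r see a source materialization (a 5->1 cast) built by
    // materializeSourceConversion, so they still type-check.
    %r:5 = scf.for %i = %lb to %ub step %s
        iter_args(%v0 = %dim, %v1 = %mem, %v2 = %ptr, %v3 = %idx, %v4 = %val)
        -> (memref<1xindex>, memref<3xindex>, memref<?xindex>,
            memref<?xindex>, memref<?xf32>) { ... }

The scf.yield inside the body is unpacked the same way by ConvertYieldOpTypes, so it yields the five memrefs directly.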
diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
new file mode 100644
index 0000000000000..13765eb12d626
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+// CHECK-LABEL: func @for(
+// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[POINTER:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[INDICES:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[VALUE:.*4]]: memref<?xf32>,
+// CHECK-SAME: %[[TMP_arg5:.*5]]: index,
+// CHECK-SAME: %[[TMP_arg6:.*6]]: index,
+// CHECK-SAME: %[[TMP_arg7:.*7]]: index
+// CHECK: %[[TMP_0:.*]]:5 = scf.for %[[TMP_arg8:.*]] = %[[TMP_arg5]] to %[[TMP_arg6]] step %[[TMP_arg7]] iter_args(
+// CHECK-SAME: %[[TMP_arg9:.*]] = %[[DIM_SIZE]],
+// CHECK-SAME: %[[TMP_arg10:.*]] = %[[MEM_SIZE]],
+// CHECK-SAME: %[[TMP_arg11:.*]] = %[[POINTER]],
+// CHECK-SAME: %[[TMP_arg12:.*]] = %[[INDICES]],
+// CHECK-SAME: %[[TMP_arg13:.*]] = %[[VALUE]])
+// CHECK: scf.yield %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]], %[[TMP_arg13]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: }
+// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2, %[[TMP_0]]#3, %[[TMP_0]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @for(%in: tensor<1024xf32, #SparseVector>,
+ %lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> {
+ %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in)
+ -> tensor<1024xf32, #SparseVector> {
+ scf.yield %vin : tensor<1024xf32, #SparseVector>
+ }
+ return %1 : tensor<1024xf32, #SparseVector>
+}
More information about the Mlir-commits mailing list