[Mlir-commits] [mlir] a85ca6b - [MLIR][Shape] Simplify shape lowering
Frederik Gossen
llvmlistbot at llvm.org
Fri Jul 24 01:44:32 PDT 2020
Author: Frederik Gossen
Date: 2020-07-24T08:44:13Z
New Revision: a85ca6be2aa8cc5f5cbeefc9f4a1181b0fa8d4cc
URL: https://github.com/llvm/llvm-project/commit/a85ca6be2aa8cc5f5cbeefc9f4a1181b0fa8d4cc
DIFF: https://github.com/llvm/llvm-project/commit/a85ca6be2aa8cc5f5cbeefc9f4a1181b0fa8d4cc.diff
LOG: [MLIR][Shape] Simplify shape lowering
Differential Revision: https://reviews.llvm.org/D84161
Added:
Modified:
mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index 7986aaaa6816..824db2685a77 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -172,39 +172,37 @@ LogicalResult
ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
ShapeOfOp::Adaptor transformed(operands);
- auto tensorVal = transformed.arg();
- auto tensorTy = tensorVal.getType();
+ Value arg = transformed.arg();
+ Type argTy = arg.getType();
// Only unranked operands are lowered here. For ranked tensors `shape_of`
// lowers to `std`; that pattern lives in the corresponding pass.
- if (tensorTy.isa<RankedTensorType>())
+ if (argTy.isa<RankedTensorType>())
return failure();
// Allocate a stack buffer (`alloca`) with one i64 slot per dimension (`rank`).
auto loc = op.getLoc();
- auto rankVal = rewriter.create<mlir::RankOp>(loc, tensorVal);
- auto i64Ty = rewriter.getI64Type();
- auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
- auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+ Value rank = rewriter.create<mlir::RankOp>(loc, arg);
+ Type i64Ty = rewriter.getI64Type();
+ Type memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+ Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
// Loop i over [0, rank): store dim(arg, i), index_cast to i64, into mem[i].
- auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
- auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
rewriter.create<scf::ForOp>(
- loc, zeroVal, rankVal, oneVal, llvm::None,
- [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
- auto dimVal = rewriter.create<DimOp>(loc, tensorVal, iVal);
- auto dimIntVal = rewriter.create<IndexCastOp>(loc, dimVal, i64Ty);
- rewriter.create<StoreOp>(loc, dimIntVal, memVal, ValueRange{iVal});
+ loc, zero, rank, one, llvm::None,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { // NOTE(review): body builds with `rewriter`, not `b`; presumably the same builder positioned in the loop body - confirm.
+ Value dim = rewriter.create<DimOp>(loc, arg, iv);
+ Value dimInt = rewriter.create<IndexCastOp>(loc, dim, i64Ty);
+ rewriter.create<StoreOp>(loc, dimInt, mem, ValueRange{iv});
rewriter.create<scf::YieldOp>(loc);
});
// Load the buffer as an i64 tensor, then index_cast it to the op's result type (replaces op).
- auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
- auto indexTy = rewriter.getIndexType();
- auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
- rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
- shapeTy);
+ Value extentTensorInt = rewriter.create<TensorLoadOp>(loc, mem);
+ rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), extentTensorInt,
+ op.getType());
return success();
}
More information about the Mlir-commits mailing list