[Mlir-commits] [mlir] a85ca6b - [MLIR][Shape] Simplify shape lowering

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 01:44:32 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T08:44:13Z
New Revision: a85ca6be2aa8cc5f5cbeefc9f4a1181b0fa8d4cc

URL: https://github.com/llvm/llvm-project/commit/a85ca6be2aa8cc5f5cbeefc9f4a1181b0fa8d4cc
DIFF: https://github.com/llvm/llvm-project/commit/a85ca6be2aa8cc5f5cbeefc9f4a1181b0fa8d4cc.diff

LOG: [MLIR][Shape] Simplify shape lowering

Differential Revision: https://reviews.llvm.org/D84161

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index 7986aaaa6816..824db2685a77 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -172,39 +172,37 @@ LogicalResult
 ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) const {
   ShapeOfOp::Adaptor transformed(operands);
-  auto tensorVal = transformed.arg();
-  auto tensorTy = tensorVal.getType();
+  Value arg = transformed.arg();
+  Type argTy = arg.getType();
 
   // For ranked tensors `shape_of` lowers to `std` and the pattern can be
   // found in the corresponding pass.
-  if (tensorTy.isa<RankedTensorType>())
+  if (argTy.isa<RankedTensorType>())
     return failure();
 
   // Allocate stack memory.
   auto loc = op.getLoc();
-  auto rankVal = rewriter.create<mlir::RankOp>(loc, tensorVal);
-  auto i64Ty = rewriter.getI64Type();
-  auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
-  auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+  Value rank = rewriter.create<mlir::RankOp>(loc, arg);
+  Type i64Ty = rewriter.getI64Type();
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
 
   // Copy shape extents to stack-allocated memory.
-  auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
-  auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
   rewriter.create<scf::ForOp>(
-      loc, zeroVal, rankVal, oneVal, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
-        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, iVal);
-        auto dimIntVal = rewriter.create<IndexCastOp>(loc, dimVal, i64Ty);
-        rewriter.create<StoreOp>(loc, dimIntVal, memVal, ValueRange{iVal});
+      loc, zero, rank, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value dim = rewriter.create<DimOp>(loc, arg, iv);
+        Value dimInt = rewriter.create<IndexCastOp>(loc, dim, i64Ty);
+        rewriter.create<StoreOp>(loc, dimInt, mem, ValueRange{iv});
         rewriter.create<scf::YieldOp>(loc);
       });
 
   // Load extents to tensor value.
-  auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
-  auto indexTy = rewriter.getIndexType();
-  auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
-                                           shapeTy);
+  Value extentTensorInt = rewriter.create<TensorLoadOp>(loc, mem);
+  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), extentTensorInt,
+                                           op.getType());
   return success();
 }
 


        


More information about the Mlir-commits mailing list