[Mlir-commits] [mlir] 42e5f42 - [mlir] Support complex numbers in Linalg promotion
Tres Popp
llvmlistbot at llvm.org
Thu Apr 29 02:59:13 PDT 2021
Author: Tres Popp
Date: 2021-04-29T11:58:57+02:00
New Revision: 42e5f42215c098face7f835f1a5a223409b85f69
URL: https://github.com/llvm/llvm-project/commit/42e5f42215c098face7f835f1a5a223409b85f69
DIFF: https://github.com/llvm/llvm-project/commit/42e5f42215c098face7f835f1a5a223409b85f69.diff
LOG: [mlir] Support complex numbers in Linalg promotion
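FillOp now accepts complex values as its fill operand, and Linalg promotion
can fill a properly sized promoted buffer with a default zero complex number.
Promoted buffer sizes are now computed through the DataLayout, so complex
element types are sized correctly.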
Differential Revision: https://reviews.llvm.org/D99939
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index b158cec529d99..8e323bd85c7bb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -182,7 +182,8 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
def FillOp : LinalgStructured_Op<"fill", []> {
let arguments = (ins AnyShaped:$output,
- AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
+ AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger,
+ AnyVector]>:$value);
let results = (outs Optional<AnyRankedTensor>:$result);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = structuredOpsDecls # [{
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d07d4f2ec8773..52f7425cbfefd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -215,9 +215,10 @@ void interchange(PatternRewriter &rewriter, LinalgOp op,
/// smallest constant value for the size of the buffer needed for each
/// dimension. If that is not possible, contains the dynamic size of the
/// subview. The call back should return the buffer to use.
-using AllocBufferCallbackFn = std::function<Optional<Value>(
- OpBuilder &b, memref::SubViewOp subView,
- ArrayRef<Value> boundingSubViewSize, OperationFolder *folder)>;
+using AllocBufferCallbackFn =
+ std::function<Optional<Value>(OpBuilder &b, memref::SubViewOp subView,
+ ArrayRef<Value> boundingSubViewSize,
+ DataLayout &layout, OperationFolder *folder)>;
/// Callback function type used to deallocate the buffers used to hold the
/// promoted subview.
@@ -315,6 +316,7 @@ struct PromotionInfo {
Optional<PromotionInfo>
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
AllocBufferCallbackFn allocationFn,
+ DataLayout &layout,
OperationFolder *folder = nullptr);
/// Promotes the `subViews` into a new buffer allocated at the insertion point
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 7c8aeb04e3d09..30da28b6e1e32 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -12,6 +12,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
@@ -49,10 +50,10 @@ using folded_memref_view = FoldedValueBuilder<memref::ViewOp>;
/// the size needed, otherwise try to allocate a static bounding box.
static Value allocBuffer(const LinalgPromotionOptions &options,
Type elementType, Value size, bool dynamicBuffers,
- OperationFolder *folder,
+ DataLayout &layout, OperationFolder *folder,
Optional<unsigned> alignment = None) {
auto *ctx = size.getContext();
- auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+ auto width = layout.getTypeSize(elementType);
IntegerAttr alignment_attr;
if (alignment.hasValue())
alignment_attr =
@@ -88,7 +89,7 @@ defaultAllocBufferCallBack(const LinalgPromotionOptions &options,
OpBuilder &builder, memref::SubViewOp subView,
ArrayRef<Value> boundingSubViewSize,
bool dynamicBuffers, Optional<unsigned> alignment,
- OperationFolder *folder) {
+ DataLayout &layout, OperationFolder *folder) {
ShapedType viewType = subView.getType();
int64_t rank = viewType.getRank();
(void)rank;
@@ -100,7 +101,7 @@ defaultAllocBufferCallBack(const LinalgPromotionOptions &options,
for (auto size : llvm::enumerate(boundingSubViewSize))
allocSize = folded_std_muli(folder, allocSize, size.value());
Value buffer = allocBuffer(options, viewType.getElementType(), allocSize,
- dynamicBuffers, folder, alignment);
+ dynamicBuffers, layout, folder, alignment);
SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
ShapedType::kDynamicSize);
Value view = folded_memref_view(
@@ -170,15 +171,16 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
}
}
- allocationFn = (options.allocationFn
- ? *(options.allocationFn)
- : [&](OpBuilder &builder, memref::SubViewOp subViewOp,
- ArrayRef<Value> boundingSubViewSize,
- OperationFolder *folder) -> Optional<Value> {
- return defaultAllocBufferCallBack(options, builder, subViewOp,
- boundingSubViewSize, dynamicBuffers,
- alignment, folder);
- });
+ allocationFn =
+ (options.allocationFn
+ ? *(options.allocationFn)
+ : [&](OpBuilder &builder, memref::SubViewOp subViewOp,
+ ArrayRef<Value> boundingSubViewSize, DataLayout &layout,
+ OperationFolder *folder) -> Optional<Value> {
+ return defaultAllocBufferCallBack(options, builder, subViewOp,
+ boundingSubViewSize, dynamicBuffers,
+ alignment, layout, folder);
+ });
deallocationFn =
(options.deallocationFn
? *(options.deallocationFn)
@@ -213,7 +215,8 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
// by a partial `copy` op.
Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
OpBuilder &b, Location loc, memref::SubViewOp subView,
- AllocBufferCallbackFn allocationFn, OperationFolder *folder) {
+ AllocBufferCallbackFn allocationFn, DataLayout &layout,
+ OperationFolder *folder) {
ScopedContext scopedContext(b, loc);
auto viewType = subView.getType();
auto rank = viewType.getRank();
@@ -236,7 +239,8 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
// If a callback is not specified, then use the default implementation for
// allocating the promoted buffer.
- Optional<Value> fullLocalView = allocationFn(b, subView, fullSizes, folder);
+ Optional<Value> fullLocalView =
+ allocationFn(b, subView, fullSizes, layout, folder);
if (!fullLocalView)
return {};
SmallVector<OpFoldResult, 4> zeros(fullSizes.size(), b.getIndexAttr(0));
@@ -248,7 +252,7 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
static Optional<MapVector<unsigned, PromotionInfo>>
promoteSubViews(OpBuilder &b, Location loc,
- LinalgOpInstancePromotionOptions options,
+ LinalgOpInstancePromotionOptions options, DataLayout &layout,
OperationFolder *folder) {
if (options.subViews.empty())
return {};
@@ -260,7 +264,7 @@ promoteSubViews(OpBuilder &b, Location loc,
memref::SubViewOp subView =
cast<memref::SubViewOp>(v.second.getDefiningOp());
Optional<PromotionInfo> promotionInfo = promoteSubviewAsNewBuffer(
- b, loc, subView, options.allocationFn, folder);
+ b, loc, subView, options.allocationFn, layout, folder);
if (!promotionInfo)
return {};
promotionInfoMap[v.first] = *promotionInfo;
@@ -269,11 +273,21 @@ promoteSubViews(OpBuilder &b, Location loc,
if (!options.useFullTileBuffers[v.second])
continue;
Value fillVal;
- if (auto t = subView.getType().getElementType().dyn_cast<FloatType>())
+ if (auto t = subView.getType().getElementType().dyn_cast<FloatType>()) {
fillVal = folded_std_constant(folder, FloatAttr::get(t, 0.0));
- else if (auto t =
- subView.getType().getElementType().dyn_cast<IntegerType>())
+ } else if (auto t =
+ subView.getType().getElementType().dyn_cast<IntegerType>()) {
fillVal = folded_std_constant_int(folder, 0, t);
+ } else if (auto t =
+ subView.getType().getElementType().dyn_cast<ComplexType>()) {
+ if (auto et = t.getElementType().dyn_cast<FloatType>())
+ fillVal = folded_std_constant(folder, FloatAttr::get(et, 0.0));
+ else if (auto et = t.getElementType().cast<IntegerType>())
+ fillVal = folded_std_constant_int(folder, 0, et);
+ fillVal = b.create<complex::CreateOp>(loc, t, fillVal, fillVal);
+ } else {
+ return {};
+ }
linalg_fill(promotionInfo->fullLocalView, fillVal);
}
@@ -292,7 +306,7 @@ promoteSubViews(OpBuilder &b, Location loc,
static Optional<LinalgOp>
promoteSubViews(OpBuilder &b, LinalgOp op,
- LinalgOpInstancePromotionOptions options,
+ LinalgOpInstancePromotionOptions options, DataLayout &layout,
OperationFolder *folder) {
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
@@ -304,7 +318,8 @@ promoteSubViews(OpBuilder &b, LinalgOp op,
// 1. Promote the specified views and use them in the new op.
auto loc = op.getLoc();
- auto promotedBuffersAndViews = promoteSubViews(b, loc, options, folder);
+ auto promotedBuffersAndViews =
+ promoteSubViews(b, loc, options, layout, folder);
if (!promotedBuffersAndViews ||
promotedBuffersAndViews->size() != options.subViews.size())
return {};
@@ -376,8 +391,8 @@ Optional<LinalgOp> mlir::linalg::promoteSubViews(OpBuilder &b,
LinalgPromotionOptions options,
OperationFolder *folder) {
LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options);
- return ::promoteSubViews(
- b, linalgOp, LinalgOpInstancePromotionOptions(linalgOp, options), folder);
+ auto layout = DataLayout::closest(linalgOp);
+ return ::promoteSubViews(b, linalgOp, linalgOptions, layout, folder);
}
namespace {
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index e4af3b52f31a3..f3092efbc5805 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -345,6 +345,29 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
// CHECK: linalg.fill(%[[v0]], %[[cf]]) : memref<?x?xf32>, f32
+func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, offset: ?, strides: [?, 1]>) {
+ %c2000 = constant 2000 : index
+ %c4000 = constant 4000 : index
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %cf = constant 1.0 : f32
+ %cc = complex.create %cf, %cf : complex<f32>
+ %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] :
+ memref<?x?xcomplex<f32>, offset: ?, strides: [?, 1]> to memref<?x?xcomplex<f32>, offset: ?, strides: [?, ?]>
+ linalg.fill(%3, %cc) { __internal_linalg_transform__ = "_promote_views_aligned_"}
+ : memref<?x?xcomplex<f32>, offset: ?, strides: [?, ?]>, complex<f32>
+ return
+}
+// CHECK-LABEL: func @aligned_promote_fill_complex
+// CHECK: %[[cc:.*]] = complex.create {{.*}} : complex<f32>
+// CHECK: %[[s0:.*]] = memref.subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xcomplex<f32>, #map{{.*}}> to memref<?x?xcomplex<f32>, #map{{.*}}>
+// CHECK: %[[a0:.*]] = memref.alloc({{%.*}}) {alignment = 32 : i64} : memref<?xi8>
+// CHECK: %[[v0:.*]] = memref.view %[[a0]][{{.*}}][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xcomplex<f32>>
+// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, #[[$STRIDED_2D_u_1]]>
+// CHECK: linalg.fill(%[[v0]], {{%.*}}) : memref<?x?xcomplex<f32>>, complex<f32>
+// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xcomplex<f32>, #map{{.*}}>, memref<?x?xcomplex<f32>, #map{{.*}}>
+// CHECK: linalg.fill(%[[v0]], %[[cc]]) : memref<?x?xcomplex<f32>>, complex<f32>
+
func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
%arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>) {
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 178de38039aa1..73282748e0c7c 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -276,6 +276,7 @@ static void fillL1TilingAndMatmulToVectorPatterns(
// Allocation call back
static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
ArrayRef<Value> boundingSubViewSize,
+ DataLayout &layout,
OperationFolder *folder) {
SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
return b
More information about the Mlir-commits mailing list