[Mlir-commits] [mlir] Make createReadOrMaskedRead a utility (PR #89119)

Lubomir Litchev llvmlistbot at llvm.org
Wed Apr 17 14:21:59 PDT 2024


https://github.com/LLITCHEV updated https://github.com/llvm/llvm-project/pull/89119

From ba6d1ecf953172b41a1d3f8a35a30b7df97a67e7 Mon Sep 17 00:00:00 2001
From: Lubo Litchev <lubol at google.com>
Date: Wed, 17 Apr 2024 18:40:54 +0000
Subject: [PATCH 1/3] Make createReadOrMaskedRead a utility

Made createReadOrMaskedRead a utility function so that it is accessible
outside of its compilation unit. It is needed by the new IREE TopK
implementation.
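
For reference, a minimal usage sketch of the new utility (not part of this
patch; the wrapper name `readInput`, the `src`/`pad` values, and the 4x8 read
shape are illustrative assumptions only):

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/IR/Builders.h"

  // Assumes the caller already has a builder, a location, a ranked tensor
  // value `src`, and a padding value `pad` of the element type.
  static mlir::Value readInput(mlir::OpBuilder &builder, mlir::Location loc,
                               mlir::Value src, mlir::Value pad) {
    // Read a static 4x8 vector from `src`; when the source shape differs
    // from the read shape, the utility emits a masked vector.transfer_read.
    llvm::SmallVector<int64_t, 2> readShape = {4, 8};
    return mlir::linalg::createReadOrMaskedRead(builder, loc, src, readShape,
                                                pad);
  }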
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  6 +++
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 29 ++++++++++++++
 .../Linalg/Transforms/Vectorization.cpp       | 40 -------------------
 3 files changed, 35 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index feb3b3f03cf538..f4c56b671e9d7e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,6 +1616,12 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
+/// Create a TransferReadOp from `source` with static shape `readShape`. If the
+/// vector type for the read is not the same as the type of `source`, then a
+/// mask is created on the read.
+Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                    Value source, ArrayRef<int64_t> readShape,
+                                    Value padValue);
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a17bc8e4cd318f..b32ebfc380fcfb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1593,3 +1593,32 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
       DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
       patterns.getContext(), benefit);
 }
+
+Value mlir::linalg::createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                    Value source, ArrayRef<int64_t> readShape,
+                                    Value padValue) {
+  assert(llvm::none_of(readShape,
+                       [](int64_t s) { return s == ShapedType::kDynamic; }));
+  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
+  assert(sourceShape.size() == readShape.size());
+  auto maskType = VectorType::get(readShape, builder.getI1Type());
+  auto vectorType = VectorType::get(readShape, padValue.getType());
+  int64_t readRank = readShape.size();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto transferReadOp = builder.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/source,
+      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(readRank, true));
+  if (llvm::equal(readShape, sourceShape)) {
+    return transferReadOp;
+  }
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(builder, loc, source);
+  Value mask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  return mlir::vector::maskOperation(builder, transferReadOp, mask)
+      ->getResult(0);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index df61381432921b..e2ca5e14377286 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1410,46 +1410,6 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
   return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
 }
 
-/// Create a TransferReadOp from `source` with static shape `readShape`. If the
-/// vector type for the read is not the same as the type of `source`, then a
-/// mask is created on the read.  If `doMasking` parameter is set to false we
-/// update the `inBounds` attribute instead of masking.
-static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
-                                    Value source, ArrayRef<int64_t> readShape,
-                                    Value padValue, bool doMasking = true) {
-  assert(llvm::none_of(readShape,
-                       [](int64_t s) { return s == ShapedType::kDynamic; }));
-  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
-  assert(sourceShape.size() == readShape.size());
-  auto maskType = VectorType::get(readShape, builder.getI1Type());
-  auto vectorType = VectorType::get(readShape, padValue.getType());
-  int64_t readRank = readShape.size();
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  SmallVector<bool> inBoundsVal(readRank, true);
-  if (!doMasking) {
-    // Update the inBounds attribute.
-    for (unsigned i = 0; i < readRank; i++)
-      inBoundsVal[i] = sourceShape[i] == readShape[i];
-  }
-  auto transferReadOp = builder.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/source,
-      /*indices=*/SmallVector<Value>(readRank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/inBoundsVal);
-
-  if (llvm::equal(readShape, sourceShape) || !doMasking) {
-    return transferReadOp;
-  }
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
-  Value mask =
-      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  return mlir::vector::maskOperation(builder, transferReadOp, mask)
-      ->getResult(0);
-}
-
 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
 /// create an empty destination tensor and create a TransferWriteOp from the
 /// input to the empty tensor. If the destination shape is not the same as the

From 10823adeb62c0c42b63d9e66706ead84bb0fb534 Mon Sep 17 00:00:00 2001
From: Lubo Litchev <lubol at google.com>
Date: Wed, 17 Apr 2024 18:48:13 +0000
Subject: [PATCH 2/3] Formatting fixes.

---
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 5 ++---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp        | 5 +++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f4c56b671e9d7e..a8175c98776775 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1619,9 +1619,8 @@ void populateSplitReductionPattern(
 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
 /// vector type for the read is not the same as the type of `source`, then a
 /// mask is created on the read.
-Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
-                                    Value source, ArrayRef<int64_t> readShape,
-                                    Value padValue);
+Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
+                             ArrayRef<int64_t> readShape, Value padValue);
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index b32ebfc380fcfb..ebc7933f7fd35b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1595,8 +1595,9 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
 }
 
 Value mlir::linalg::createReadOrMaskedRead(OpBuilder &builder, Location loc,
-                                    Value source, ArrayRef<int64_t> readShape,
-                                    Value padValue) {
+                                           Value source,
+                                           ArrayRef<int64_t> readShape,
+                                           Value padValue) {
   assert(llvm::none_of(readShape,
                        [](int64_t s) { return s == ShapedType::kDynamic; }));
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();

From 0898983570b7458ad8ed22065d37a5d9592954fd Mon Sep 17 00:00:00 2001
From: Lubo Litchev <lubol at google.com>
Date: Wed, 17 Apr 2024 21:21:10 +0000
Subject: [PATCH 3/3] Merge the latest changes and restore the doMasking parameter.
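
With this change callers can opt out of masking. A hedged sketch reusing the
illustrative setup from the first commit message (`builder`, `loc`, `src`,
`pad`, and `readShape` are assumptions, not part of this patch):

  // With doMasking=false no vector.create_mask is generated; instead the
  // inBounds attribute is set per dimension by comparing the source shape
  // with readShape.
  mlir::Value read = mlir::linalg::createReadOrMaskedRead(
      builder, loc, src, readShape, pad, /*doMasking=*/false);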

---
 .../mlir/Dialect/Linalg/Transforms/Transforms.h     |  3 ++-
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp   | 13 ++++++++++---
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a8175c98776775..db6b23c5894941 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1620,7 +1620,8 @@ void populateSplitReductionPattern(
 /// vector type for the read is not the same as the type of `source`, then a
 /// mask is created on the read.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
-                             ArrayRef<int64_t> readShape, Value padValue);
+                             ArrayRef<int64_t> readShape, Value padValue,
+                             bool doMasking = true);
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index ebc7933f7fd35b..b4d70c464e1268 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1597,7 +1597,7 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
 Value mlir::linalg::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                            Value source,
                                            ArrayRef<int64_t> readShape,
-                                           Value padValue) {
+                                           Value padValue, bool doMasking) {
   assert(llvm::none_of(readShape,
                        [](int64_t s) { return s == ShapedType::kDynamic; }));
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
@@ -1606,14 +1606,21 @@ Value mlir::linalg::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto vectorType = VectorType::get(readShape, padValue.getType());
   int64_t readRank = readShape.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<bool> inBoundsVal(readRank, true);
+  if (!doMasking) {
+    // Update the inBounds attribute.
+    for (unsigned i = 0; i < readRank; i++)
+      inBoundsVal[i] = sourceShape[i] == readShape[i];
+  }
   auto transferReadOp = builder.create<vector::TransferReadOp>(
       loc,
       /*vectorType=*/vectorType,
       /*source=*/source,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(readRank, true));
-  if (llvm::equal(readShape, sourceShape)) {
+      /*inBounds=*/inBoundsVal);
+
+  if (llvm::equal(readShape, sourceShape) || !doMasking) {
     return transferReadOp;
   }
   SmallVector<OpFoldResult> mixedSourceDims =


