[Mlir-commits] [mlir] [mlir][MemRef] Add subview folding pattern for vector.maskedload (PR #71380)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 6 04:28:49 PST 2023


https://github.com/tyb0807 updated https://github.com/llvm/llvm-project/pull/71380

>From e090a67ccbb1568935c342599680212f504ddfd2 Mon Sep 17 00:00:00 2001
From: tyb0807 <vuson at google.com>
Date: Mon, 6 Nov 2023 11:28:55 +0000
Subject: [PATCH] [mlir][MemRef] Add subview folding pattern for
 vector.maskedload

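Fold a memref.subview feeding a vector.maskedload by rewriting the masked load
to read directly from the subview's source memref with the resolved source
indices, mirroring the existing folding for vector.load. The mask and
pass-through operands are forwarded unchanged.

This is required to fix openxla/iree#15031.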
---
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  |  8 ++++++++
 .../Dialect/MemRef/fold-memref-alias-ops.mlir | 19 ++++++++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 9899c357daeeeb4..043e8fbcdd2f6fb 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -187,6 +187,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
 
 static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
 
+static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
+
 static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.getSource();
 }
@@ -415,6 +417,11 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
         rewriter.replaceOpWithNewOp<vector::LoadOp>(
             op, op.getType(), subViewOp.getSource(), sourceIndices);
       })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), subViewOp.getSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Case([&](vector::TransferReadOp op) {
         rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
             op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
@@ -687,6 +694,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                LoadOpOfSubViewOpFolder<memref::LoadOp>,
                LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
                LoadOpOfSubViewOpFolder<vector::LoadOp>,
+               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
                LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
                LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
                StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 47ec0e67d99cb3e..3f11e22749bb16d 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -654,7 +654,7 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 // -----
 
 func.func @fold_vector_load(
-  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {  
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
   return %1 : vector<12x32xf32>
@@ -665,3 +665,20 @@ func.func @fold_vector_load(
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
 //      CHECK:   vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] :  memref<12x32xf32>, vector<12x32xf32>
+
+// -----
+
+func.func @fold_vector_maskedload(
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+  return %1 : vector<32xf32>
+}
+
+//      CHECK: func @fold_vector_maskedload
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
+//      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>



More information about the Mlir-commits mailing list