[Mlir-commits] [mlir] e5c8fc7 - [mlir][vector] canonicalize unmasked gather/scatter/compress/expand directly into l/s

Aart Bik llvmlistbot at llvm.org
Fri Mar 5 14:24:07 PST 2021


Author: Aart Bik
Date: 2021-03-05T14:23:50-08:00
New Revision: e5c8fc776fbd2c93e25f5749049ee31cf73a0a41

URL: https://github.com/llvm/llvm-project/commit/e5c8fc776fbd2c93e25f5749049ee31cf73a0a41
DIFF: https://github.com/llvm/llvm-project/commit/e5c8fc776fbd2c93e25f5749049ee31cf73a0a41.diff

LOG: [mlir][vector] canonicalize unmasked gather/scatter/compress/expand directly into l/s

With the new vector.load/store operations, there is no need to go through
unmasked transfer operations (which will be canonicalized to load/store anyway).

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D98056

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Linalg/sparse_vector.mlir
    mlir/test/Dialect/Vector/vector-mem-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index ebe3250a09af..8945253644d8 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2764,8 +2764,8 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
                                 PatternRewriter &rewriter) const override {
     switch (get1DMaskFormat(load.mask())) {
     case MaskFormat::AllTrue:
-      rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-          load, load.getType(), load.base(), load.indices(), false);
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(),
+                                                  load.base(), load.indices());
       return success();
     case MaskFormat::AllFalse:
       rewriter.replaceOp(load, load.pass_thru());
@@ -2809,8 +2809,8 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
                                 PatternRewriter &rewriter) const override {
     switch (get1DMaskFormat(store.mask())) {
     case MaskFormat::AllTrue:
-      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-          store, store.valueToStore(), store.base(), store.indices(), false);
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          store, store.valueToStore(), store.base(), store.indices());
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(store);
@@ -2951,8 +2951,8 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
                                 PatternRewriter &rewriter) const override {
     switch (get1DMaskFormat(expand.mask())) {
     case MaskFormat::AllTrue:
-      rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-          expand, expand.getType(), expand.base(), expand.indices(), false);
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(
+          expand, expand.getType(), expand.base(), expand.indices());
       return success();
     case MaskFormat::AllFalse:
       rewriter.replaceOp(expand, expand.pass_thru());
@@ -2996,9 +2996,9 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
                                 PatternRewriter &rewriter) const override {
     switch (get1DMaskFormat(compress.mask())) {
     case MaskFormat::AllTrue:
-      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
           compress, compress.valueToStore(), compress.base(),
-          compress.indices(), false);
+          compress.indices());
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(compress);

diff  --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir
index beed6e8bc9aa..5409e4b80681 100644
--- a/mlir/test/Dialect/Linalg/sparse_vector.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir
@@ -35,10 +35,10 @@
 // CHECK-VEC1-DAG:   %[[c16:.*]] = constant 16 : index
 // CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
-// CHECK-VEC1:         %[[r:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
 // CHECK-VEC1:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
 // CHECK-VEC1:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
-// CHECK-VEC1:         vector.transfer_write %[[m]], %{{.*}}[%[[i]]] {masked = [false]} : vector<16xf32>, memref<1024xf32>
+// CHECK-VEC1:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
 // CHECK-VEC1:       }
 // CHECK-VEC1:       return
 //
@@ -47,10 +47,10 @@
 // CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
 // CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
-// CHECK-VEC2:         %[[r:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
 // CHECK-VEC2:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
 // CHECK-VEC2:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
-// CHECK-VEC2:         vector.transfer_write %[[m]], %{{.*}}[%[[i]]] {masked = [false]} : vector<16xf32>, memref<1024xf32>
+// CHECK-VEC2:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
@@ -214,8 +214,8 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
 // CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC1-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC1:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
-// CHECK-VEC1:         %[[la:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
-// CHECK-VEC1:         %[[lb:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
 // CHECK-VEC1:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
 // CHECK-VEC1:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
 // CHECK-VEC1:         scf.yield %[[a]] : vector<16xf32>
@@ -229,8 +229,8 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
 // CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC2-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC2:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
-// CHECK-VEC2:         %[[la:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
-// CHECK-VEC2:         %[[lb:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
 // CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
 // CHECK-VEC2:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
 // CHECK-VEC2:         scf.yield %[[a]] : vector<16xf32>

diff  --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index ab46a7863a55..bf2c0770bc3e 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -4,8 +4,7 @@
 // CHECK-SAME:                      %[[A0:.*]]: memref<?xf32>,
 // CHECK-SAME:                      %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-DAG:       %[[C:.*]] = constant 0 : index
-// CHECK-DAG:       %[[D:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT:      %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT:      %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<?xf32>, vector<16xf32>
 // CHECK-NEXT:      return %[[T]] : vector<16xf32>
 func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = constant 0 : index
@@ -19,8 +18,7 @@ func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16
 // CHECK-SAME:                      %[[A0:.*]]: memref<16xf32>,
 // CHECK-SAME:                      %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-DAG:       %[[C:.*]] = constant 0 : index
-// CHECK-DAG:       %[[D:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT:      %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT:      %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
 // CHECK-NEXT:      return %[[T]] : vector<16xf32>
 func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = constant 0 : index
@@ -46,8 +44,7 @@ func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<1
 // CHECK-SAME:                      %[[A0:.*]]: memref<?xf32>,
 // CHECK-SAME:                      %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-DAG:       %[[C:.*]] = constant 8 : index
-// CHECK-DAG:       %[[D:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT:      %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT:      %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<?xf32>, vector<16xf32>
 // CHECK-NEXT:      return %[[T]] : vector<16xf32>
 func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c8 = constant 8 : index
@@ -61,7 +58,7 @@ func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16
 // CHECK-SAME:                       %[[A0:.*]]: memref<16xf32>,
 // CHECK-SAME:                       %[[A1:.*]]: vector<16xf32>) {
 // CHECK-NEXT:      %[[C:.*]] = constant 0 : index
-// CHECK-NEXT:      vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
+// CHECK-NEXT:      vector.store %[[A1]], %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
 // CHECK-NEXT:      return
 func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
@@ -144,8 +141,7 @@ func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
 // CHECK-SAME:                  %[[A0:.*]]: memref<16xf32>,
 // CHECK-SAME:                  %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-DAG:       %[[C:.*]] = constant 0 : index
-// CHECK-DAG:       %[[D:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT:      %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT:      %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
 // CHECK-NEXT:      return %[[T]] : vector<16xf32>
 func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = constant 0 : index
@@ -171,7 +167,7 @@ func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf3
 // CHECK-SAME:                    %[[A0:.*]]: memref<16xf32>,
 // CHECK-SAME:                    %[[A1:.*]]: vector<16xf32>) {
 // CHECK-NEXT:      %[[C:.*]] = constant 0 : index
-// CHECK-NEXT:      vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
+// CHECK-NEXT:      vector.store %[[A1]], %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
 // CHECK-NEXT:      return
 func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
   %c0 = constant 0 : index


        


More information about the Mlir-commits mailing list