[Mlir-commits] [mlir] a586c55 - [mlir][vector] Add mask fold test for gather lowering

Jakub Kuderski llvmlistbot at llvm.org
Thu Mar 16 07:20:12 PDT 2023


Author: Jakub Kuderski
Date: 2023-03-16T10:16:15-04:00
New Revision: a586c551000bcd874852ea1265f6dac4b3d894b3

URL: https://github.com/llvm/llvm-project/commit/a586c551000bcd874852ea1265f6dac4b3d894b3
DIFF: https://github.com/llvm/llvm-project/commit/a586c551000bcd874852ea1265f6dac4b3d894b3.diff

LOG: [mlir][vector] Add mask fold test for gather lowering

Check that the `scf.if` checks are folded away when the mask is a constant
with all lanes set or no lanes set.

This is to address post-commit feedback for
https://reviews.llvm.org/D145942.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D146144
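The folding that the new tests check can be modeled in a few lines. The sketch below is an illustrative Python model, not MLIR. It assumes the 1-D gather is unrolled into one `scf.if` per lane. In that model, the then-branch reads `base[offset + indices[i]]`, as `tensor.extract` would, and the else-branch keeps `pass_thru[i]`. The name `masked_gather_1d` is hypothetical.

With a constant all-true mask, every branch condition is known to be true, so only the loads remain. With a constant all-false mask, nothing is loaded and the result is simply `pass_thru`. Those are the two shapes the CANON checks in the diff expect.

```python
from typing import List, Sequence


def masked_gather_1d(base: Sequence[float], offset: int,
                     indices: Sequence[int], mask: Sequence[bool],
                     pass_thru: Sequence[float]) -> List[float]:
    """Hypothetical per-lane model of an unrolled 1-D vector.gather.

    Each lane i stands in for one `scf.if` on mask[i]: the then-branch
    loads base[offset + indices[i]], the else-branch keeps pass_thru[i].
    """
    if not (len(indices) == len(mask) == len(pass_thru)):
        raise ValueError("indices, mask and pass_thru must have equal length")
    result = list(pass_thru)
    for lane, (idx, enabled) in enumerate(zip(indices, mask)):
        if enabled:  # models the per-lane scf.if condition
            result[lane] = base[offset + idx]
    return result


base = [10.0, 20.0, 30.0]
# All lanes set: every lane loads, so no conditional is needed.
print(masked_gather_1d(base, 0, [2, 0], [True, True], [-1.0, -1.0]))    # [30.0, 10.0]
# No lanes set: base is never read and the result is the pass-through.
print(masked_gather_1d(base, 0, [2, 0], [False, False], [-1.0, -2.0]))  # [-1.0, -2.0]
```

In the none-set case the model never touches `base`, even for out-of-range indices. This matches the expectation that `@gather_tensor_1d_none_set` canonicalizes down to `return [[PASS]]`.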

Added: 
    

Modified: 
    mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5afd2fc73a7cf..c98a71df26c2f 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s --test-vector-gather-lowering | FileCheck %s
+// RUN: mlir-opt %s --test-vector-gather-lowering --canonicalize | FileCheck %s --check-prefix=CANON
 
 // CHECK-LABEL: @gather_memref_1d
 // CHECK-SAME:    ([[BASE:%.+]]: memref<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
@@ -125,3 +126,28 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
   %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
  }
+
+// Check that all-set and none-set masks get optimized out after canonicalization.
+
+// CANON-LABEL: @gather_tensor_1d_all_set
+// CANON-NOT:     scf.if
+// CANON:         tensor.extract
+// CANON:         tensor.extract
+// CANON:         [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : f32 into vector<2xf32>
+// CANON-NEXT:    return [[FINAL]] : vector<2xf32>
+func.func @gather_tensor_1d_all_set(%base: tensor<?xf32>, %v: vector<2xindex>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %mask = arith.constant dense <true> : vector<2xi1>
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CANON-LABEL: @gather_tensor_1d_none_set
+// CANON-SAME:    ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[PASS:%.+]]: vector<2xf32>)
+// CANON-NEXT:    return [[PASS]] : vector<2xf32>
+func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %mask = arith.constant dense <false> : vector<2xi1>
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
