[Mlir-commits] [mlir] c3839c0 - CombineContractBroadcast should not create dims unused in LHS+RHS

Benoit Jacob llvmlistbot at llvm.org
Mon Jul 4 09:52:52 PDT 2022


Author: Benoit Jacob
Date: 2022-07-04T16:52:35Z
New Revision: c3839c0b46a902f96f9395ec15d89d0cb21c73b3

URL: https://github.com/llvm/llvm-project/commit/c3839c0b46a902f96f9395ec15d89d0cb21c73b3
DIFF: https://github.com/llvm/llvm-project/commit/c3839c0b46a902f96f9395ec15d89d0cb21c73b3.diff

LOG: CombineContractBroadcast should not create dims unused in LHS+RHS

Differential Revision: https://reviews.llvm.org/D129087

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index de94f43708fad..952fbde9b487f 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -18,6 +18,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 namespace llvm {
 class SmallBitVector;
@@ -584,6 +585,11 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
   map.print(os);
   return os;
 }
+
+// Return a bitvector with one bit per dimension of `maps[0]`, in which each set
+// bit marks a dimension that none of the maps in the input array `maps` uses.
+llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
+
 } // namespace mlir
 
 namespace llvm {

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6c1ba2161b83c..a3874795571d5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -687,6 +687,9 @@ static LogicalResult verifyOutputShape(
     MLIRContext *ctx = op.getContext();
     AffineMap lhsMap = op.getIndexingMaps()[0];
     AffineMap rhsMap = op.getIndexingMaps()[1];
+    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+      return op.emitOpError(
+          "expected all dimensions to be either a LHS or a RHS dimension");
     SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
     for (auto pair :
          {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
@@ -699,8 +702,8 @@ static LogicalResult verifyOutputShape(
       }
     }
     if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
-      return op.emitOpError("expected all input dimensions to be used by "
-                            "either the LHS or the RHS");
+      return op.emitOpError("expected all dimensions to get an extent as "
+                            "either a LHS or a RHS dimension");
 
     AffineMap resMap = op.getIndexingMaps()[2];
     auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 80e2a0a3f7ea3..bb8cc2bfae396 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -32,7 +32,6 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -1155,21 +1154,14 @@ struct CombineContractBroadcast
 
     // Determine which dims are usused, now that the maps have been composed
     // with the broadcast maps.
-    unsigned numDims = maps[0].getNumDims();
-    llvm::SmallBitVector unusedDims(numDims, true);
-    for (const auto &m : maps) {
-      for (unsigned i = 0; i < numDims; ++i) {
-        if (m.isFunctionOfDim(i))
-          unusedDims.reset(i);
-      }
-    }
+    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
     // Compress unused dims.
     for (auto &m : maps)
-      m = compressDims(m, unusedDims);
+      m = compressDims(m, unusedDimsBitVector);
     // Compute the combined iterators.
     SmallVector<Attribute, 4> iterators;
-    for (unsigned i = 0; i < numDims; ++i) {
-      if (!unusedDims.test(i))
+    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
+      if (!unusedDimsBitVector.test(i))
         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
     }
     // Check that compressing unused dims isn't removing all reduction
@@ -1179,7 +1171,10 @@ struct CombineContractBroadcast
     // a reduction iterator.
     if (!llvm::any_of(iterators, isReductionIterator))
       return failure();
-
+    // If the compressed maps have a dimension that is not used by either LHS or
+    // RHS then the ContractionOp verifier would fail.
+    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
+      return failure();
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
         contractOp, lhs, rhs, contractOp.getAcc(),
         rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index c93b4c28d769a..a8de9579edae6 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -560,12 +560,7 @@ AffineMap mlir::compressDims(AffineMap map,
 }
 
 AffineMap mlir::compressUnusedDims(AffineMap map) {
-  llvm::SmallBitVector unusedDims(map.getNumDims(), true);
-  map.walkExprs([&](AffineExpr expr) {
-    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
-      unusedDims.reset(dimExpr.getPosition());
-  });
-  return compressDims(map, unusedDims);
+  return compressDims(map, getUnusedDimsBitVector(map));
 }
 
 static SmallVector<AffineMap>
@@ -722,6 +717,18 @@ AffineMap mlir::getProjectedMap(AffineMap map,
   return compressUnusedSymbols(compressDims(map, unusedDims));
 }
 
+llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
+  // No maps means no dim space; don't index maps[0] out of bounds.
+  if (maps.empty())
+    return llvm::SmallBitVector();
+  // One bit per dim of maps[0], kept set until some map is found to use it.
+  llvm::SmallBitVector unusedDims(maps[0].getNumDims(), true);
+  for (unsigned i = 0, e = unusedDims.size(); i < e; ++i)
+    if (llvm::any_of(maps, [&](AffineMap m) { return m.isFunctionOfDim(i); }))
+      unusedDims.reset(i);
+  return unusedDims;
+}
+
 //===----------------------------------------------------------------------===//
 // MutableAffineMap.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 9a2afb9cc2774..87e5f9443807c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -875,7 +875,7 @@ func.func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: ve
 // -----
 
 func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
-// expected-error@+1 {{'vector.contract' op expected all input dimensions to be used by either the LHS or the RHS}}
+// expected-error@+1 {{'vector.contract' op expected all dimensions to be either a LHS or a RHS dimension}}
   %result = vector.contract {
     indexing_maps = [
       affine_map<(d0, d1, d2) -> (d0, d2)>,

diff  --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 8549327326b80..ade539e278226 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -159,6 +159,10 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+
 // CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
 //  CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
@@ -178,6 +182,37 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto
     return %result : vector<8x8xi32>
 }
 
+// -----
+
+// Test that CombineContractBroadcast is not combining this case, as that would
+// result in a dimension being unused in the LHS and RHS maps, which is illegal.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
+//  CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
+//  CHECK: vector.contract
+//  CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+//  CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
+//  CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
+
+func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
+  %bcast = vector.broadcast %arg1 : vector<2xi32> to vector<1x1x2xi32>
+  %result = vector.contract {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["reduction", "parallel", "reduction"],
+      kind = #vector.kind<add>
+  } %arg0, %bcast, %arg2 : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
+  return %result : vector<1xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.


        


More information about the Mlir-commits mailing list