[Mlir-commits] [mlir] [MLIR][Linalg] Scalable Vectorization of Reduction (PR #97788)

Zhaoshi Zheng llvmlistbot at llvm.org
Thu Jul 11 10:38:38 PDT 2024


================
@@ -1947,13 +1956,30 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (inputVectorSizes.empty())
     return success();
 
+  auto linalgOp = dyn_cast<LinalgOp>(op);
+  if (linalgOp && isLinalgReduction(linalgOp)) {
+    LDBG("Checking reduce op dims for scalable vectorization\n");
+    auto iteratorTypes = linalgOp.getIteratorTypesArray();
+    assert(iteratorTypes.size() == inputScalableVecDims.size() &&
+           "Number of iterator types and input scalable dims mismatch");
+    // For now, only support scalable vectorization of a reduction on the
+    // trailing dim.
+    for (size_t i = 0; i < inputScalableVecDims.size() - 1; ++i) {
+      if (inputScalableVecDims[i] && isReductionIterator(iteratorTypes[i])) {
+        LDBG("Non-trailing reduction dim requested for scalable "
+             "vectorization\n");
+        return failure();
+      }
+    }
+    return success();
+  }
+
   bool isScalable = inputScalableVecDims.back();
   if (!isScalable)
     return success();
 
   // Only element-wise and 1d depthwise conv ops supported in the presence of
   // scalable dims.
-  auto linalgOp = dyn_cast<LinalgOp>(op);
   return success(linalgOp && (isElementwise(linalgOp) ||
                               isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
----------------
zhaoshiz wrote:

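Reduce ops are already checked above, by the new lines L1961-L1975. That block returns early, so reductions never reach this element-wise/depthwise-conv check.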
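The control flow described above can be sketched as a standalone C++ model. It is a simplified illustration, not the MLIR code. `OpKind`, `IteratorKind` and `isScalableVectorizable` are hypothetical stand-ins, whereas the real function works on a `LinalgOp` and its iterator types. The model shows the order of the checks: reductions are handled first and return early. Only a non-reduction op falls through to the "trailing dim scalable implies element-wise or 1-D depthwise conv" rule.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for MLIR concepts; NOT the real MLIR API.
enum class IteratorKind { Parallel, Reduction };
enum class OpKind { Reduction, Elementwise, DepthwiseConv1D, Other };

// Simplified model of vectorizeScalableVectorPrecondition as changed in the diff.
// Returns true where the real code would return success().
bool isScalableVectorizable(OpKind kind,
                            const std::vector<IteratorKind> &iterators,
                            const std::vector<bool> &scalableDims) {
  // Models the inputVectorSizes.empty() early exit (assumes the scalable-dim
  // flags have the same length as the requested vector sizes).
  if (scalableDims.empty())
    return true;

  // New reduction branch: handled first, always returns.
  if (kind == OpKind::Reduction) {
    assert(iterators.size() == scalableDims.size() &&
           "Number of iterator types and input scalable dims mismatch");
    // Reject a scalable reduction dim anywhere except the trailing position.
    for (std::size_t i = 0; i + 1 < scalableDims.size(); ++i)
      if (scalableDims[i] && iterators[i] == IteratorKind::Reduction)
        return false;
    return true;
  }

  // Pre-existing logic: a non-scalable trailing dim is always accepted.
  if (!scalableDims.back())
    return true;

  // Scalable trailing dim: only element-wise and 1-D depthwise conv ops.
  return kind == OpKind::Elementwise || kind == OpKind::DepthwiseConv1D;
}
```

For example, a reduction with iterators `{Parallel, Reduction}` and scalable flags `{false, true}` is accepted, because the scalable reduction is on the trailing dim. With iterators `{Reduction, Parallel}` and flags `{true, false}` it is rejected. The loop bound is written as `i + 1 < size` so the model stays safe for an empty vector. In the diff, the earlier `inputVectorSizes.empty()` check plays that role, assuming both arrays have the same length.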

https://github.com/llvm/llvm-project/pull/97788

