[Mlir-commits] [mlir] [mlir][vector] NFC - Add more structured interface support to vector.contract (PR #145313)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jun 23 06:37:19 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/145313

>From 91f7ea29a737e1ecd959ec64820264c4ecaedddc Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Mon, 23 Jun 2025 13:41:33 +0200
Subject: [PATCH 1/3] [mlir][vector] NFC - Add more structured interface
 support to vector.contract

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 44 +++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 926a92eff2ebb..12362be4d7d30 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -207,6 +207,50 @@ def Vector_ContractionOp :
               .template getAsValueRange<IteratorTypeAttr, IteratorType>();
       return {range.begin(), range.end()};
     }
+
+    //===------------------------------------------------------------------===//
+    // The code below is shared with LinalgStructuredInterface.
+    // vector.contract is really a linalg.generic on vectors without region.
+    // TODO: factor out in a common interface to inherit from ince identified.
+    //===------------------------------------------------------------------===//
+    ArrayRef<int64_t> getShape(OpOperand * opOperand) {
+      assert(opOperand->getOwner() == this->getOperation());
+      // Non-vector operands (e.g. a scalar accumulator when every iterator is
+      // a reduction) have an empty shape, mirroring Linalg's getShape.
+      if (auto vectorType = dyn_cast<VectorType>(opOperand->get().getType()))
+        return vectorType.getShape();
+      return {};
+    }
+
+    /// Concatenation of all indexing maps: maps the loop dimensions to the
+    /// flattened list of operand dimensions.
+    AffineMap getLoopsToShapesMap() {
+      auto maps = getIndexingMapsArray();
+      return concatAffineMaps(maps, getContext());
+    }
+
+    /// Inverse of getLoopsToShapesMap(). Per inversePermutation's contract,
+    /// this is a null AffineMap when the concatenated map is not invertible.
+    AffineMap getShapesToLoopsMap() {
+      return inversePermutation(getLoopsToShapesMap());
+    }
+
+    SmallVector<int64_t> getStaticShape(){
+      SmallVector<int64_t> res;
+      for (OpOperand &opOperand : this->getOperation()->getOpOperands())
+        llvm::append_range(res, getShape(&opOperand));
+      // Sizes of all operands, flattened in operand order.
+      return res;
+    }
+
+    /// Static size of every loop dimension, recovered by mapping the
+    /// flattened operand sizes back through getShapesToLoopsMap().
+    SmallVector<int64_t> getStaticLoopRanges() {
+      SmallVector<int64_t> viewSizes = getStaticShape();
+      AffineMap invertedMap = getShapesToLoopsMap();
+      assert(invertedMap && "expected invertible indexing maps");
+      return invertedMap.compose(viewSizes);
+    }
   }];
 
   let hasCanonicalizer = 1;

>From 33effcd839a97e12cfcb42f6185a90c83cb3bbd2 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <Nico.Vasilache at amd.com>
Date: Mon, 23 Jun 2025 15:37:02 +0200
Subject: [PATCH 2/3] Update mlir/include/mlir/Dialect/Vector/IR/VectorOps.td: add missing space before '{' in getStaticShape()

Co-authored-by: Fabian Mora <fmora.dev at gmail.com>
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 12362be4d7d30..583600384a39e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -228,7 +228,7 @@ def Vector_ContractionOp :
       return inversePermutation(getLoopsToShapesMap());
     }
 
-    SmallVector<int64_t> getStaticShape(){
+    SmallVector<int64_t> getStaticShape() {
       SmallVector<int64_t> res;
       for (OpOperand &opOperand : this->getOperation()->getOpOperands())
         llvm::append_range(res, getShape(&opOperand));

>From 31d8d860bfb769af816ef63ba4b06159ef20424a Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <Nico.Vasilache at amd.com>
Date: Mon, 23 Jun 2025 15:37:11 +0200
Subject: [PATCH 3/3] Update mlir/include/mlir/Dialect/Vector/IR/VectorOps.td: fix "ince" -> "once" typo in TODO comment

Co-authored-by: Fabian Mora <fmora.dev at gmail.com>
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 583600384a39e..f4bbc2843903c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -211,7 +211,7 @@ def Vector_ContractionOp :
     //===------------------------------------------------------------------===//
     // The code below is shared with LinalgStructuredInterface.
     // vector.contract is really a linalg.generic on vectors without region.
-    // TODO: factor out in a common interface to inherit from ince identified.
+    // TODO: factor out in a common interface to inherit from once identified.
     //===------------------------------------------------------------------===//
     ArrayRef<int64_t> getShape(OpOperand * opOperand) {
       assert(opOperand->getOwner() == this->getOperation());



More information about the Mlir-commits mailing list