From 41f54ef6a73af1e094689e36f35f377197c513fd Mon Sep 17 00:00:00 2001
From: Martin Lang <martin.lang@mpsd.mpg.de>
Date: Thu, 27 Feb 2025 10:13:28 +0100
Subject: [PATCH] Record intel modules to patch and always patch all intel
 modules

---
 src/mpsd_software_manager/spack.py | 53 ++++++++++++++++++++++--------
 1 file changed, 40 insertions(+), 13 deletions(-)

diff --git a/src/mpsd_software_manager/spack.py b/src/mpsd_software_manager/spack.py
index 1aa4c37..2c08de4 100644
--- a/src/mpsd_software_manager/spack.py
+++ b/src/mpsd_software_manager/spack.py
@@ -15,6 +15,7 @@ from pathlib import Path
 from typing import Any, Callable
 
 import jinja2
+import yaml
 
 from .config import Config
 from .util import abort
@@ -307,6 +308,9 @@ def spack_install_package(package: str) -> None:
 
 def refresh_modules(compilers: dict[str, Any] | None = None) -> None:
     """Create lmod modules and change family of intel compiler modules."""
+    CLASSIC_FAMILY = "intel_classic_compiler"
+    ONEAPI_FAMILY = "intel_oneapi_compiler"
+
     logger.info("refreshing lmod modules")
     try:
         spack("module lmod refresh -y", log_callback=logger.debug)
@@ -317,16 +321,25 @@ def refresh_modules(compilers: dict[str, Any] | None = None) -> None:
     # - to allow loading gcc and intel compiler simultaneously, we replace
     #   'family("compiler")' in the intel/oneapi lmod files
     # - add gcc as dependent module to intel compilers
+    to_patch_file = Config().spack_root / "etc" / "mpsd_intel_module_patching.yaml"
+
+    # record new intel module(s) to patch
     if compilers and "intel" in compilers["default"]["package"]:
+        try:
+            with open(to_patch_file) as f:
+                modules_to_patch = yaml.load(f, Loader=yaml.Loader)
+        except FileNotFoundError:
+            modules_to_patch = {}
+
         intel_module = (
             compilers["default"]["package"].split("%")[0].replace("@", "/") + ".lua"
         )
         gcc_module = compilers["fallback"]["package"].split("%")[0].replace("@", "/")
         module_file = Config().lmod_root / "Core" / intel_module
-        CLASSIC_FAMILY = "intel_classic_compiler"
-        ONEAPI_FAMILY = "intel_oneapi_compiler"
         family = CLASSIC_FAMILY if "classic" in intel_module else ONEAPI_FAMILY
-        patch_intel_module(module_file, gcc_module, family)
+
+        modules_to_patch[module_file] = {"family": family, "gcc_module": gcc_module}
+
         if family == "intel_classic_compiler":
             # We need to also patch the intel module, which is loaded as a dependency of
             # the intel classic module. We read the module name from the intel classic
@@ -338,26 +351,40 @@ def refresh_modules(compilers: dict[str, Any] | None = None) -> None:
             # match.group(1) will fail should we not find the line (should never happen)
             oneapi_module = match.group(1) + ".lua"
             oneapi_module_file = Config().lmod_root / "Core" / oneapi_module
-            patch_intel_module(oneapi_module_file, gcc_module, ONEAPI_FAMILY)
+            modules_to_patch[oneapi_module_file] = {
+                "family": ONEAPI_FAMILY,
+                "gcc_module": gcc_module,
+            }
+
+        with open(to_patch_file, "w") as f:
+            yaml.dump(modules_to_patch, f)
+
+    # patch all intel modules
+    try:
+        with open(to_patch_file) as f:
+            modules_to_patch = yaml.load(f)
+    except FileNotFoundError:
+        logger.debug("No intel modules to patch")
+        return
+
+    for module_file in modules_to_patch:
+        gcc_module = modules_to_patch["module_file"]["gcc_module"]
+        family = modules_to_patch["module_file"]["family"]
+        patch_intel_module(module_file, gcc_module, family)
 
 
 def patch_intel_module(module_file: Path, gcc_module: str, family: str) -> None:
     """Make intel modules depend on gcc and change family."""
-    logger.debug(
-        "Updating family(...) in '%s'",
-        module_file,
-    )
+    if not module_file.exists():
+        logger.warning("Module '%s' does not exist; skipping", module_file)
+    logger.debug("Updating family(...) in '%s'", module_file)
     content = module_file.read_text()
     content = content.replace('family("compiler")', f'family("{family}")')
     # Insert gcc as dependency before other dependencies and before modifying
     # MODULE_PATH.
     # This function can be called multiple times, so only insert if not yet present.
     if f'depends_on("{gcc_module}")\n' not in content:
-        logger.debug(
-            "Inserting dependency on '%s' in '%s'",
-            gcc_module,
-            module_file,
-        )
+        logger.debug("Inserting dependency on '%s' in '%s'", gcc_module, module_file)
         insertion_point = min(
             content.find("depends_on("), content.find('prepend_path("MODULEPATH"')
         )
-- 
GitLab