gabriel / muse public
config.py python
796 lines 24.3 KB
80353726 feat: muse auth + hub + config — paradigm-level identity architecture w… Gabriel Cardona <cgcardona@gmail.com> 4d ago
1 """Muse CLI configuration helpers.
2
3 Reads and writes ``.muse/config.toml`` — the per-repository configuration
4 file. Credentials (bearer tokens) are **not** stored here; they live in
5 ``~/.muse/identity.toml`` managed by :mod:`muse.core.identity`.
6
7 Config schema
8 -------------
9 ::
10
11 [user]
12 name = "Alice" # display name (human or agent handle)
13 email = "a@example.com"
14 type = "human" # "human" | "agent"
15
16 [hub]
17 url = "https://musehub.ai" # MuseHub fabric endpoint for this repo
18
19 [remotes.origin]
20 url = "https://hub.muse.io/repos/my-repo"
21 branch = "main"
22
23 [domain]
24 # Domain-specific key/value pairs; read by the active domain plugin.
25 # ticks_per_beat = "480"
26
27 Settable via ``muse config set``
28 ---------------------------------
29 - ``user.name``, ``user.email``, ``user.type``
30 - ``hub.url`` (alias: ``muse hub connect <url>``)
31 - ``domain.*``
32
33 Not settable via ``muse config set``
34 --------------------------------------
35 - ``remotes.*`` — use ``muse remote add/remove``
36 - credentials — use ``muse auth login``
37
38 Token resolution
39 ----------------
40 :func:`get_auth_token` reads the hub URL from this file, then resolves the
41 bearer token from ``~/.muse/identity.toml`` via
42 :func:`muse.core.identity.resolve_token`. The token is **never** logged.
43 """
44
45 from __future__ import annotations
46
47 import logging
48 import pathlib
49 import shutil
50 import subprocess
51 import tomllib
52 from typing import TypedDict
53
# Module logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# The per-repository config file lives at <repo_root>/.muse/config.toml.
_CONFIG_FILENAME: str = "config.toml"
_MUSE_DIR: str = ".muse"
58
59
60 # ---------------------------------------------------------------------------
61 # Named configuration types
62 # ---------------------------------------------------------------------------
63
64
class UserConfig(TypedDict, total=False):
    """Typed view of the ``[user]`` table; every key may be absent."""

    name: str  # display name: a human's name or an agent's handle
    email: str  # contact e-mail address
    type: str  # who is acting: "human" or "agent"
71
72
class HubConfig(TypedDict, total=False):
    """Typed view of the ``[hub]`` table; the key may be absent."""

    url: str  # MuseHub endpoint for this repository
77
78
class RemoteEntry(TypedDict, total=False):
    """Typed view of one ``[remotes.<name>]`` table; both keys may be absent."""

    url: str  # location of the remote
    branch: str  # branch written by set_upstream and matched by get_upstream
84
85
class MuseConfig(TypedDict, total=False):
    """Typed view of the whole ``.muse/config.toml``; every section is optional."""

    user: UserConfig  # [user]
    hub: HubConfig  # [hub]
    remotes: dict[str, RemoteEntry]  # [remotes.<name>] tables, keyed by remote name
    domain: dict[str, str]  # [domain] key/value pairs (string values only)
93
94
class RemoteConfig(TypedDict):
    """One remote as reported by :func:`list_remotes`; both keys are required."""

    name: str  # remote name, e.g. "origin"
    url: str  # the remote's URL with surrounding whitespace stripped
100
101
102 # ---------------------------------------------------------------------------
103 # Internal helpers
104 # ---------------------------------------------------------------------------
105
106
def _config_path(repo_root: pathlib.Path | None) -> pathlib.Path:
    """Return the absolute path of ``.muse/config.toml``.

    Args:
        repo_root: Repository root, or ``None`` to use the current directory.

    Returns:
        ``<resolved root>/.muse/config.toml`` (the file may not exist).
    """
    base = repo_root if repo_root else pathlib.Path.cwd()
    return base.resolve().joinpath(_MUSE_DIR, _CONFIG_FILENAME)
111
112
113 def _load_config(config_path: pathlib.Path) -> MuseConfig:
114 """Load and parse config.toml; return an empty MuseConfig if absent."""
115 if not config_path.is_file():
116 return {}
117
118 try:
119 with config_path.open("rb") as fh:
120 raw = tomllib.load(fh)
121 except Exception as exc: # noqa: BLE001
122 logger.warning("⚠️ Failed to parse %s: %s", config_path, exc)
123 return {}
124
125 config: MuseConfig = {}
126
127 user_raw = raw.get("user")
128 if isinstance(user_raw, dict):
129 user: UserConfig = {}
130 name_val = user_raw.get("name")
131 if isinstance(name_val, str):
132 user["name"] = name_val
133 email_val = user_raw.get("email")
134 if isinstance(email_val, str):
135 user["email"] = email_val
136 type_val = user_raw.get("type")
137 if isinstance(type_val, str):
138 user["type"] = type_val
139 config["user"] = user
140
141 hub_raw = raw.get("hub")
142 if isinstance(hub_raw, dict):
143 hub: HubConfig = {}
144 url_val = hub_raw.get("url")
145 if isinstance(url_val, str):
146 hub["url"] = url_val
147 config["hub"] = hub
148
149 remotes_raw = raw.get("remotes")
150 if isinstance(remotes_raw, dict):
151 remotes: dict[str, RemoteEntry] = {}
152 for name, remote_raw in remotes_raw.items():
153 if isinstance(remote_raw, dict):
154 entry: RemoteEntry = {}
155 rurl = remote_raw.get("url")
156 if isinstance(rurl, str):
157 entry["url"] = rurl
158 branch_val = remote_raw.get("branch")
159 if isinstance(branch_val, str):
160 entry["branch"] = branch_val
161 remotes[name] = entry
162 config["remotes"] = remotes
163
164 domain_raw = raw.get("domain")
165 if isinstance(domain_raw, dict):
166 domain: dict[str, str] = {}
167 for key, val in domain_raw.items():
168 if isinstance(val, str):
169 domain[key] = val
170 config["domain"] = domain
171
172 return config
173
174
175 def _escape(value: str) -> str:
176 """Escape a TOML string value (backslash and double-quote)."""
177 return value.replace("\\", "\\\\").replace('"', '\\"')
178
179
180 def _dump_toml(config: MuseConfig) -> str:
181 """Serialise a MuseConfig to TOML text.
182
183 Section order: ``[user]``, ``[hub]``, ``[remotes.*]``, ``[domain]``.
184 """
185 lines: list[str] = []
186
187 user = config.get("user")
188 if user:
189 lines.append("[user]")
190 name = user.get("name", "")
191 if name:
192 lines.append(f'name = "{_escape(name)}"')
193 email = user.get("email", "")
194 if email:
195 lines.append(f'email = "{_escape(email)}"')
196 utype = user.get("type", "")
197 if utype:
198 lines.append(f'type = "{_escape(utype)}"')
199 lines.append("")
200
201 hub = config.get("hub")
202 if hub:
203 lines.append("[hub]")
204 url = hub.get("url", "")
205 if url:
206 lines.append(f'url = "{_escape(url)}"')
207 lines.append("")
208
209 remotes = config.get("remotes") or {}
210 for remote_name in sorted(remotes):
211 entry = remotes[remote_name]
212 lines.append(f"[remotes.{remote_name}]")
213 rurl = entry.get("url", "")
214 if rurl:
215 lines.append(f'url = "{_escape(rurl)}"')
216 branch = entry.get("branch", "")
217 if branch:
218 lines.append(f'branch = "{_escape(branch)}"')
219 lines.append("")
220
221 domain = config.get("domain") or {}
222 if domain:
223 lines.append("[domain]")
224 for key, val in sorted(domain.items()):
225 lines.append(f'{key} = "{_escape(val)}"')
226 lines.append("")
227
228 return "\n".join(lines)
229
230
231 # ---------------------------------------------------------------------------
232 # Auth token resolution (via identity store)
233 # ---------------------------------------------------------------------------
234
235
def get_auth_token(repo_root: pathlib.Path | None = None) -> str | None:
    """Look up the bearer token for the hub this repository points at.

    The hub URL comes from ``[hub] url`` in ``.muse/config.toml`` (via
    :func:`get_hub_url`). The token itself is fetched from
    ``~/.muse/identity.toml`` by :func:`muse.core.identity.resolve_token`.
    Only the hub URL is ever logged; the token value is **never** logged.

    Args:
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Returns:
        The bearer token, or ``None`` when no hub is configured or no
        identity is stored for it.
    """
    # Imported lazily: a module-level import would be circular.
    from muse.core.identity import resolve_token

    hub_url = get_hub_url(repo_root)
    if hub_url is None:
        logger.debug("⚠️ No hub configured — skipping auth token lookup")
        return None

    token = resolve_token(hub_url)
    if token is not None:
        logger.debug("✅ Auth token resolved for hub %s (Bearer ***)", hub_url)
        return token

    logger.debug("⚠️ No identity for hub %s — run `muse auth login`", hub_url)
    return None
266
267
268 # ---------------------------------------------------------------------------
269 # Hub helpers
270 # ---------------------------------------------------------------------------
271
272
def get_hub_url(repo_root: pathlib.Path | None = None) -> str | None:
    """Look up ``[hub] url`` in ``.muse/config.toml``.

    Surrounding whitespace is stripped. A missing section, a missing key or
    a blank value all mean "no hub".

    Args:
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Returns:
        The stripped hub URL, or ``None`` when no hub is configured.
    """
    hub_section = _load_config(_config_path(repo_root)).get("hub")
    if hub_section is None:
        return None
    stripped_url = hub_section.get("url", "").strip()
    return stripped_url or None
288
289
def set_hub_url(url: str, repo_root: pathlib.Path | None = None) -> None:
    """Write ``[hub] url`` to ``.muse/config.toml``.

    Keeps the other sections this module understands (``[user]``,
    ``[remotes.*]``, ``[domain]``). Creates ``.muse/`` and the config file
    if absent.

    Only ``https://`` URLs are accepted, because Muse never contacts a hub
    over cleartext HTTP. The URL must also name a host. Otherwise a value
    such as ``"https://"`` would be stored and later passed to token
    resolution by :func:`get_auth_token`.

    Args:
        url: Hub URL (must be ``https://`` with a host name).
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Raises:
        ValueError: If *url* does not use the ``https://`` scheme or has no
            host name.
    """
    from urllib.parse import urlsplit  # local import, as in get_auth_token

    if not url.startswith("https://"):
        raise ValueError(
            f"Hub URL must use HTTPS. Got: {url!r}\n"
            "Muse never connects to a hub over cleartext HTTP."
        )
    try:
        host = urlsplit(url).hostname
    except ValueError:  # e.g. an unbalanced IPv6 bracket in the netloc
        host = None
    if not host:
        raise ValueError(f"Hub URL must include a host name. Got: {url!r}")

    cp = _config_path(repo_root)
    cp.parent.mkdir(parents=True, exist_ok=True)
    config = _load_config(cp)
    config["hub"] = HubConfig(url=url)
    cp.write_text(_dump_toml(config), encoding="utf-8")
    logger.info("✅ Hub URL set to %s", url)
314
315
def clear_hub_url(repo_root: pathlib.Path | None = None) -> None:
    """Remove the ``[hub]`` section from ``.muse/config.toml``.

    The file is rewritten only when it actually contains a ``[hub]`` section.
    Disconnecting a repository with no hub therefore does not fail, even when
    ``.muse/`` itself is missing. In that case the config is not created or
    re-serialised either, so an unparseable config is not overwritten with an
    empty one.

    Args:
        repo_root: Repository root. Defaults to ``Path.cwd()``.
    """
    cp = _config_path(repo_root)
    config = _load_config(cp)
    if "hub" in config:
        del config["hub"]
        cp.write_text(_dump_toml(config), encoding="utf-8")
    else:
        logger.debug("No [hub] section in %s — nothing to remove", cp)
    logger.info("✅ Hub disconnected")
328
329
330 # ---------------------------------------------------------------------------
331 # User config helpers
332 # ---------------------------------------------------------------------------
333
334
def get_user_config(repo_root: pathlib.Path | None = None) -> UserConfig:
    """Return the ``[user]`` table, or an empty UserConfig when it is absent or empty.

    Args:
        repo_root: Repository root. Defaults to ``Path.cwd()``.
    """
    user_section = _load_config(_config_path(repo_root)).get("user")
    return user_section if user_section else {}
339
340
341 def set_user_field(key: str, value: str, repo_root: pathlib.Path | None = None) -> None:
342 """Set a single ``[user]`` field by name.
343
344 Allowed keys: ``name``, ``email``, ``type``.
345
346 Args:
347 key: Field name within ``[user]``.
348 value: New value.
349 repo_root: Repository root. Defaults to ``Path.cwd()``.
350
351 Raises:
352 ValueError: If *key* is not a recognised user config field.
353 """
354 if key not in {"name", "email", "type"}:
355 raise ValueError(f"Unknown [user] config key: {key!r}. Valid keys: name, email, type")
356 cp = _config_path(repo_root)
357 cp.parent.mkdir(parents=True, exist_ok=True)
358 config = _load_config(cp)
359 user: UserConfig = config.get("user") or {}
360 if key == "name":
361 user["name"] = value
362 elif key == "email":
363 user["email"] = value
364 elif key == "type":
365 user["type"] = value
366 config["user"] = user
367 cp.write_text(_dump_toml(config), encoding="utf-8")
368 logger.info("✅ user.%s = %r", key, value)
369
370
371 # ---------------------------------------------------------------------------
372 # Generic dotted-key helpers
373 # ---------------------------------------------------------------------------
374
# Namespaces `muse config set` refuses, mapped to the hint shown instead.
_BLOCKED_NAMESPACES: dict[str, str] = dict(
    auth="Use `muse auth login` to manage credentials.",
    remotes="Use `muse remote add/remove/rename` to manage remotes.",
)

# Namespaces `muse config set` is allowed to write.
_SETTABLE_NAMESPACES: set[str] = {"domain", "hub", "user"}
381
382
383 def get_config_value(key: str, repo_root: pathlib.Path | None = None) -> str | None:
384 """Get a config value by dotted key (e.g. ``user.name``, ``hub.url``).
385
386 Returns ``None`` when the key is not set or the namespace is unknown.
387
388 Args:
389 key: Dotted key in ``<namespace>.<subkey>`` form.
390 repo_root: Repository root. Defaults to ``Path.cwd()``.
391
392 Returns:
393 String value, or ``None``.
394 """
395 parts = key.split(".", 1)
396 if len(parts) != 2:
397 return None
398 namespace, subkey = parts
399 config = _load_config(_config_path(repo_root))
400
401 if namespace == "user":
402 user = config.get("user") or {}
403 if subkey == "name":
404 return user.get("name")
405 if subkey == "email":
406 return user.get("email")
407 if subkey == "type":
408 return user.get("type")
409 return None
410
411 if namespace == "hub":
412 hub = config.get("hub") or {}
413 if subkey == "url":
414 return hub.get("url")
415 return None
416
417 if namespace == "domain":
418 domain = config.get("domain") or {}
419 return domain.get(subkey)
420
421 return None
422
423
def set_config_value(key: str, value: str, repo_root: pathlib.Path | None = None) -> None:
    """Set a config value by dotted key (e.g. ``user.name``, ``domain.ticks_per_beat``).

    The key is split at its first dot. ``user.*`` is delegated to
    :func:`set_user_field` and ``hub.url`` to :func:`set_hub_url`; the latter
    enforces HTTPS. Both read and write the config themselves, so the file is
    only loaded here for ``domain.*`` keys. As a result, a rejected hub URL or
    user key never creates ``.muse/`` as a side effect.

    Args:
        key: Dotted key in ``<namespace>.<subkey>`` form; neither part may
            be empty.
        value: New string value.
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Raises:
        ValueError: If the key is malformed (no dot, empty namespace or empty
            subkey), the namespace is blocked or unknown, or the subkey is
            invalid for its namespace.
    """
    parts = key.split(".", 1)
    # An empty subkey ("domain.") would be written as ` = "value"`, which is
    # invalid TOML and makes the whole config unreadable on the next load.
    if len(parts) != 2 or not parts[0] or not parts[1]:
        raise ValueError(f"Key must be in 'namespace.subkey' form, got: {key!r}")
    namespace, subkey = parts

    if namespace in _BLOCKED_NAMESPACES:
        raise ValueError(_BLOCKED_NAMESPACES[namespace])

    if namespace not in _SETTABLE_NAMESPACES:
        raise ValueError(
            f"Unknown config namespace {namespace!r}. "
            f"Settable namespaces: {', '.join(sorted(_SETTABLE_NAMESPACES))}"
        )

    if namespace == "user":
        set_user_field(subkey, value, repo_root)
        return

    if namespace == "hub":
        if subkey != "url":
            raise ValueError(f"Unknown [hub] config key: {subkey!r}. Valid keys: url")
        # Route through set_hub_url — it enforces the HTTPS requirement.
        set_hub_url(value, repo_root)
        return

    # namespace == "domain": the only namespace written directly here.
    cp = _config_path(repo_root)
    cp.parent.mkdir(parents=True, exist_ok=True)
    config = _load_config(cp)
    domain: dict[str, str] = config.get("domain") or {}
    domain[subkey] = value
    config["domain"] = domain
    cp.write_text(_dump_toml(config), encoding="utf-8")
    logger.info("✅ domain.%s = %r", subkey, value)
470
471
def config_as_dict(repo_root: pathlib.Path | None = None) -> dict[str, dict[str, str]]:
    """Flatten the config into ``{section: {key: value}}`` for JSON output.

    Only non-empty values are included, and empty sections are left out.
    ``remotes`` maps each remote name (sorted) to its URL; the ``branch``
    field is not exported. ``domain`` keys are sorted. Credentials never
    appear because this file does not store any.

    Args:
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Returns:
        Nested dict suitable for ``json.dumps``.
    """
    config = _load_config(_config_path(repo_root))
    result: dict[str, dict[str, str]] = {}

    if user := config.get("user"):
        user_dict: dict[str, str] = {}
        if uname := user.get("name"):
            user_dict["name"] = uname
        if uemail := user.get("email"):
            user_dict["email"] = uemail
        if utype := user.get("type"):
            user_dict["type"] = utype
        if user_dict:
            result["user"] = user_dict

    if (hub := config.get("hub")) and (hub_url := hub.get("url", "")):
        result["hub"] = {"url": hub_url}

    remotes = config.get("remotes") or {}
    remote_urls: dict[str, str] = {
        rname: rurl
        for rname in sorted(remotes)
        if (rurl := remotes[rname].get("url", ""))
    }
    if remote_urls:
        result["remotes"] = remote_urls

    if domain := config.get("domain"):
        result["domain"] = dict(sorted(domain.items()))

    return result
522
523
def config_path_for_editor(repo_root: pathlib.Path | None = None) -> pathlib.Path:
    """Return the file that the ``config edit`` command should open.

    Args:
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Returns:
        Absolute path to ``.muse/config.toml`` (the file may not exist yet).
    """
    return _config_path(repo_root=repo_root)
527
528
529 # ---------------------------------------------------------------------------
530 # Remote helpers
531 # ---------------------------------------------------------------------------
532
533
def get_remote(name: str, repo_root: pathlib.Path | None = None) -> str | None:
    """Look up the URL of remote *name*.

    Args:
        name: Remote name (e.g. ``"origin"``).
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Returns:
        The whitespace-stripped URL, or ``None`` when the remote is unknown
        or has no (or a blank) URL.
    """
    remotes = _load_config(_config_path(repo_root)).get("remotes") or {}
    entry = remotes.get(name)
    if entry is None:
        return None
    stripped_url = entry.get("url", "").strip()
    return stripped_url or None
553
554
def set_remote(
    name: str,
    url: str,
    repo_root: pathlib.Path | None = None,
) -> None:
    """Create or update ``[remotes.<name>] url`` in ``.muse/config.toml``.

    An existing remote keeps its recorded ``branch``; only the URL changes.
    ``.muse/`` and the config file are created if absent.

    Args:
        name: Remote name (e.g. ``"origin"``).
        url: Remote URL.
        repo_root: Repository root. Defaults to ``Path.cwd()``.
    """
    config_file = _config_path(repo_root)
    config_file.parent.mkdir(parents=True, exist_ok=True)
    config = _load_config(config_file)

    remotes: dict[str, RemoteEntry] = dict(config.get("remotes") or {})
    previous = remotes.get(name)
    updated: RemoteEntry = previous.copy() if previous is not None else {}
    updated["url"] = url
    remotes[name] = updated

    config["remotes"] = remotes
    config_file.write_text(_dump_toml(config), encoding="utf-8")
    logger.info("✅ Remote %r set to %s", name, url)
588
589
def remove_remote(
    name: str,
    repo_root: pathlib.Path | None = None,
) -> None:
    """Remove a named remote and its tracking refs.

    Deletes ``[remotes.<name>]`` from ``.muse/config.toml``, then deletes the
    tracking-ref directory ``.muse/remotes/<name>/`` if it exists. The
    directory is only deleted when it resolves to a location strictly inside
    ``.muse/remotes/``. A hand-edited remote name such as ``".."`` (or an
    absolute path) would otherwise point :func:`shutil.rmtree` at ``.muse/``
    itself or somewhere outside the repository. Such a name is logged and
    its directory is left alone.

    Args:
        name: Remote name to remove.
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Raises:
        KeyError: If *name* is not a configured remote.
    """
    cp = _config_path(repo_root)
    config = _load_config(cp)
    remotes = config.get("remotes")
    if remotes is None or name not in remotes:
        raise KeyError(name)
    del remotes[name]
    config["remotes"] = remotes
    cp.write_text(_dump_toml(config), encoding="utf-8")
    logger.info("✅ Remote %r removed from config", name)

    root = (repo_root or pathlib.Path.cwd()).resolve()
    refs_root = root / _MUSE_DIR / "remotes"
    refs_dir = refs_root / name
    resolved_root = refs_root.resolve()
    resolved_dir = refs_dir.resolve()
    if resolved_dir == resolved_root or not resolved_dir.is_relative_to(resolved_root):
        logger.warning(
            "⚠️ Not removing tracking refs for remote %r: %s is outside %s",
            name,
            resolved_dir,
            resolved_root,
        )
        return
    if refs_dir.is_dir():
        shutil.rmtree(refs_dir)
        logger.debug("✅ Removed tracking refs dir %s", refs_dir)
618
619
def rename_remote(
    old_name: str,
    new_name: str,
    repo_root: pathlib.Path | None = None,
) -> None:
    """Rename a remote in the config, then move its tracking-ref directory.

    The config is rewritten first. ``.muse/remotes/<old_name>/`` is renamed
    to ``.muse/remotes/<new_name>/`` only if it exists.

    Args:
        old_name: Current remote name.
        new_name: Desired new remote name.
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Raises:
        KeyError: If *old_name* is not a configured remote.
        ValueError: If *new_name* is already configured.
    """
    config_file = _config_path(repo_root)
    config = _load_config(config_file)
    remotes = config.get("remotes") or {}
    if old_name not in remotes:
        raise KeyError(old_name)
    if new_name in remotes:
        raise ValueError(new_name)

    remotes[new_name] = remotes.pop(old_name)
    config["remotes"] = remotes
    config_file.write_text(_dump_toml(config), encoding="utf-8")
    logger.info("✅ Remote %r renamed to %r", old_name, new_name)

    refs_base = (repo_root or pathlib.Path.cwd()).resolve() / _MUSE_DIR / "remotes"
    old_refs_dir = refs_base / old_name
    new_refs_dir = refs_base / new_name
    if not old_refs_dir.is_dir():
        return
    old_refs_dir.rename(new_refs_dir)
    logger.debug("✅ Moved tracking refs dir %s → %s", old_refs_dir, new_refs_dir)
654
655
def list_remotes(repo_root: pathlib.Path | None = None) -> list[RemoteConfig]:
    """List configured remotes that have a non-blank URL, sorted by name.

    Args:
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Returns:
        ``{"name": str, "url": str}`` dicts with whitespace-stripped URLs.
    """
    remotes = _load_config(_config_path(repo_root)).get("remotes")
    if remotes is None:
        return []
    return [
        RemoteConfig(name=remote_name, url=clean_url)
        for remote_name in sorted(remotes)
        if (clean_url := remotes[remote_name].get("url", "").strip())
    ]
676
677
678 # ---------------------------------------------------------------------------
679 # Remote tracking-head helpers
680 # ---------------------------------------------------------------------------
681
682
def _remote_head_path(
    remote_name: str,
    branch: str,
    repo_root: pathlib.Path | None = None,
) -> pathlib.Path:
    """Return ``<root>/.muse/remotes/<remote_name>/<branch>``, the tracking pointer file.

    Args:
        remote_name: Remote name.
        branch: Branch name.
        repo_root: Repository root. Defaults to ``Path.cwd()``.
    """
    base = repo_root if repo_root else pathlib.Path.cwd()
    return base.resolve().joinpath(_MUSE_DIR, "remotes", remote_name, branch)
691
692
def get_remote_head(
    remote_name: str,
    branch: str,
    repo_root: pathlib.Path | None = None,
) -> str | None:
    """Read the last-known remote commit ID for *remote_name*/*branch*.

    Args:
        remote_name: Remote name (e.g. ``"origin"``).
        branch: Branch name (e.g. ``"main"``).
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Returns:
        The whitespace-stripped commit ID, or ``None`` when the pointer file
        is missing or blank.
    """
    pointer_file = _remote_head_path(remote_name, branch, repo_root)
    if pointer_file.is_file():
        commit_id = pointer_file.read_text(encoding="utf-8").strip()
        if commit_id:
            return commit_id
    return None
715
716
def set_remote_head(
    remote_name: str,
    branch: str,
    commit_id: str,
    repo_root: pathlib.Path | None = None,
) -> None:
    """Record *commit_id* as the known HEAD of *remote_name*/*branch*.

    Args:
        remote_name: Remote name (e.g. ``"origin"``).
        branch: Branch name.
        commit_id: Commit ID to store in the tracking pointer file.
        repo_root: Repository root. Defaults to ``Path.cwd()``.
    """
    pointer_file = _remote_head_path(remote_name, branch, repo_root)
    # The remote's tracking directory may not exist yet.
    pointer_file.parent.mkdir(parents=True, exist_ok=True)
    pointer_file.write_text(commit_id, encoding="utf-8")
    logger.debug("✅ Remote head %s/%s → %s", remote_name, branch, commit_id[:8])
735
736
737 # ---------------------------------------------------------------------------
738 # Upstream tracking helpers
739 # ---------------------------------------------------------------------------
740
741
def set_upstream(
    branch: str,
    remote_name: str,
    repo_root: pathlib.Path | None = None,
) -> None:
    """Record *remote_name* as the upstream remote for *branch*.

    Stores *branch* as ``[remotes.<remote_name>] branch`` and keeps that
    remote's URL. Any *other* remote that currently records the same branch
    has its ``branch`` field removed. :func:`get_upstream` returns the first
    remote whose ``branch`` matches, so a stale entry on another remote would
    otherwise keep shadowing the new upstream.

    Args:
        branch: Local (and remote) branch name.
        remote_name: Remote name.
        repo_root: Repository root. Defaults to ``Path.cwd()``.
    """
    cp = _config_path(repo_root)
    cp.parent.mkdir(parents=True, exist_ok=True)
    config = _load_config(cp)
    remotes: dict[str, RemoteEntry] = dict(config.get("remotes") or {})

    # Same matching rule as get_upstream (stripped comparison).
    for other_name, other_entry in list(remotes.items()):
        if other_name == remote_name:
            continue
        if other_entry.get("branch", "").strip() == branch:
            detached: RemoteEntry = {}
            if "url" in other_entry:
                detached["url"] = other_entry["url"]
            remotes[other_name] = detached

    entry: RemoteEntry = {}
    existing_entry = remotes.get(remote_name)
    if existing_entry is not None and "url" in existing_entry:
        entry["url"] = existing_entry["url"]
    entry["branch"] = branch
    remotes[remote_name] = entry

    config["remotes"] = remotes
    cp.write_text(_dump_toml(config), encoding="utf-8")
    logger.info("✅ Upstream for branch %r set to %s/%r", branch, remote_name, branch)
773
774
def get_upstream(
    branch: str,
    repo_root: pathlib.Path | None = None,
) -> str | None:
    """Find the remote whose recorded ``branch`` equals *branch*.

    Remotes are checked in config order; the first whose whitespace-stripped
    ``branch`` matches wins.

    Args:
        branch: Local branch name.
        repo_root: Repository root. Defaults to ``Path.cwd()``.

    Returns:
        Remote name string, or ``None`` when no remote tracks *branch*.
    """
    remotes = _load_config(_config_path(repo_root)).get("remotes")
    if remotes is None:
        return None
    return next(
        (
            rname
            for rname, entry in remotes.items()
            if entry.get("branch", "").strip() == branch
        ),
        None,
    )