diff --git a/src/git_sim/commands.py b/src/git_sim/commands.py index 61dc588..03a8efe 100644 --- a/src/git_sim/commands.py +++ b/src/git_sim/commands.py @@ -259,14 +259,33 @@ def push( def rebase( - branch: str = typer.Argument( + newbase: str = typer.Argument( ..., - help="The branch to simulate rebasing the checked-out commit onto", - ) + help="The new base commit to simulate rebasing to", + ), + rebase_merges: bool = typer.Option( + False, + "--rebase-merges", + "-r", + help="Preserve merge structure during rebase", + ), + onto: bool = typer.Option( + False, + "--onto", + help="Rebase onto given branch instead of upstream", + ), + oldbase: str = typer.Argument( + None, + help="The parent of the commit to rebase (to be used with --onto)", + ), + until: str = typer.Argument( + None, + help="The commit to rebase up to and including (to be used with --onto)", + ), ): from git_sim.rebase import Rebase - scene = Rebase(branch=branch) + scene = Rebase(newbase=newbase, rebase_merges=rebase_merges, onto=onto, oldbase=oldbase, until=until) handle_animations(scene=scene) diff --git a/src/git_sim/config.py b/src/git_sim/config.py index c901006..9669b34 100644 --- a/src/git_sim/config.py +++ b/src/git_sim/config.py @@ -12,7 +12,7 @@ from typing import List from git.repo import Repo from argparse import Namespace -from configparser import NoSectionError +from configparser import NoSectionError, NoOptionError from git.exc import GitCommandError, InvalidGitRepositoryError from git_sim.settings import settings @@ -195,6 +195,9 @@ def add_details(self): except NoSectionError: print(f"git-sim error: section '{section}' doesn't exist in config") sys.exit(1) + except NoOptionError: + print(f"git-sim error: option '{option}' doesn't exist in config") + sys.exit(1) elif len(self.settings) == 2: value = self.settings[1].strip('"').strip("'").strip("\\") section_text = ( diff --git a/src/git_sim/git_sim_base_command.py b/src/git_sim/git_sim_base_command.py index c07a8e1..087b270 100644 --- a/src/git_sim/git_sim_base_command.py +++ b/src/git_sim/git_sim_base_command.py @@ -1,18 +1,22 @@ import os -import platform -import shutil -import stat +import git import sys +import stat +import numpy +import random +import shutil +import platform import tempfile -import git import manim as m -import numpy -from git.exc import GitCommandError, InvalidGitRepositoryError + from git.repo import Repo +from git.exc import GitCommandError, InvalidGitRepositoryError, BadName + +from collections import deque -from git_sim.enums import ColorByOptions, StyleOptions from git_sim.settings import settings +from git_sim.enums import ColorByOptions, StyleOptions class GitSimBaseCommand(m.MovingCameraScene): @@ -104,7 +108,11 @@ def construct(self): def get_commit(self, sha_or_ref="HEAD"): if self.head_exists(): - return self.repo.commit(sha_or_ref) + try: + return self.repo.commit(sha_or_ref) + except BadName: + print(f"git-sim error: {sha_or_ref} did not resolve to a valid Git object.") + sys.exit(1) return "dark" def get_default_commits(self): @@ -402,7 +410,7 @@ def get_nonparent_branch_names(self): exclude = [] for b1 in branches: for b2 in branches: - if b1.name != b2.name: + if b1.name != b2.name and b1.commit != b2.commit: if self.repo.is_ancestor(b1.commit, b2.commit): exclude.append(b1.name) return [b for b in branches if b.name not in exclude] @@ -947,7 +955,7 @@ def center_frame_on_commit(self, commit): else: self.camera.frame.move_to(self.drawnCommits[commit.hexsha].get_center()) - def reset_head_branch(self, hexsha, shift=numpy.array([0.0, 0.0, 0.0])): + def reset_head_branch(self, hexsha, branch="HEAD", shift=numpy.array([0.0, 0.0, 0.0])): if not self.head_exists(): return @@ -960,7 +968,7 @@ def reset_head_branch(self, hexsha, shift=numpy.array([0.0, 0.0, 0.0])): 0, ) ), - self.drawnRefs[self.repo.active_branch.name].animate.move_to( + self.drawnRefs[self.repo.active_branch.name if branch == "HEAD" else branch].animate.move_to( ( self.drawnCommits[hexsha].get_center()[0] + shift[0], self.drawnCommits[hexsha].get_center()[1] + 2 + shift[1], @@ -976,7 +984,7 @@ def reset_head_branch(self, hexsha, shift=numpy.array([0.0, 0.0, 0.0])): 0, ) ) - self.drawnRefs[self.repo.active_branch.name].move_to( + self.drawnRefs[self.repo.active_branch.name if branch == "HEAD" else branch].move_to( ( self.drawnCommits[hexsha].get_center()[0] + shift[0], self.drawnCommits[hexsha].get_center()[1] + 2 + shift[1], @@ -1323,12 +1331,17 @@ def add_group_to_author_groups(self, author, group): def show_command_as_title(self): if settings.show_command_as_title: - titleText = m.Text( - self.trim_cmd(self.cmd), - font=self.font, - font_size=36, - color=self.fontColor, - ) + title_len = 100 + while 1: + titleText = m.Text( + self.trim_cmd(self.cmd, title_len), + font=self.font, + font_size=36, + color=self.fontColor, + ) + if titleText.width < self.camera.frame.width: + break + title_len -= 5 top = 0 for element in self.toFadeOut: if element.get_top()[1] > top: @@ -1345,6 +1358,7 @@ def show_command_as_title(self): color=self.fontColor, ) self.toFadeOut.add(titleText, ul) + self.recenter_frame() self.scale_frame() if settings.animate: self.play(m.AddTextLetterByLetter(titleText), m.Create(ul)) @@ -1375,6 +1389,47 @@ def add_ref_to_drawn_refs_by_commit(self, hexsha, ref): ref, ] + def generate_random_sha(self): + valid_chars = "0123456789abcdef" + return "".join(random.choices(valid_chars, k=6)) + + def get_shortest_distance(self, sha_or_ref1, sha_or_ref2): + # Create a queue for BFS that stores (commit, depth) tuples + queue = deque([(self.repo.commit(sha_or_ref2), 0)]) + visited = set() + + # Perform BFS from the start commit + while queue: + current_commit, depth = queue.popleft() + + # If we reach the end commit + if current_commit.hexsha == self.repo.commit(sha_or_ref1).hexsha: + return depth + + # Mark this commit as visited + visited.add(current_commit.hexsha) + + # Queue all unvisited parents + for parent in current_commit.parents: + if parent.hexsha not in visited: + queue.append((parent, depth + 1)) + + # If no path found + return -1 + + def is_on_mainline(self, sha_or_ref1, sha_or_ref2): + current_commit = self.get_commit(sha_or_ref2) + + # Traverse the first parent history + while current_commit: + if current_commit.hexsha == self.get_commit(sha_or_ref1).hexsha: + return True + if current_commit.parents: + current_commit = current_commit.parents[0] + else: + break + return False + class DottedLine(m.Line): def __init__(self, *args, dot_spacing=0.4, dot_kwargs={}, **kwargs): diff --git a/src/git_sim/rebase.py b/src/git_sim/rebase.py index a455ae3..d76e734 100644 --- a/src/git_sim/rebase.py +++ b/src/git_sim/rebase.py @@ -1,106 +1,196 @@ import sys import git -import manim as m import numpy +import random + +import manim as m -from git_sim.git_sim_base_command import GitSimBaseCommand from git_sim.settings import settings +from git_sim.git_sim_base_command import GitSimBaseCommand, DottedLine class Rebase(GitSimBaseCommand): - def __init__(self, branch: str): + def __init__(self, newbase: str, rebase_merges: bool, onto: bool, oldbase: str, until: str): super().__init__() - self.branch = branch + self.newbase = newbase + self.rebase_merges = rebase_merges + self.onto = onto + self.oldbase = oldbase + self.until = until + + self.non_merge_reached = False try: - git.repo.fun.rev_parse(self.repo, self.branch) + git.repo.fun.rev_parse(self.repo, self.newbase) except git.exc.BadName: print( "git-sim error: '" - + self.branch + + self.newbase + "' is not a valid Git ref or identifier." ) sys.exit(1) - if self.branch in [branch.name for branch in self.repo.heads]: - self.selected_branches.append(self.branch) + if self.onto: + if not self.oldbase: + print( + "git-sim error: When using --onto, please specify as the parent of the commit to rebase" + ) + sys.exit(1) + elif not self.is_on_mainline(self.oldbase, "HEAD"): + print( + "git-sim error: Currently only mainline commit paths (i.e. paths traced following _first parents_) along to HEAD are supported for rebase simulations" + ) + sys.exit(1) + self.n = max(self.get_shortest_distance(self.oldbase, "HEAD"), self.n) + + if self.until: + self.until_n = self.get_shortest_distance(self.oldbase, self.until) + else: + if self.oldbase or self.until: + print( + "git-sim error: Please use --onto flag when specifying and " + ) + sys.exit(1) + + if self.newbase in [branch.name for branch in self.repo.heads]: + self.selected_branches.append(self.newbase) try: self.selected_branches.append(self.repo.active_branch.name) except TypeError: pass - self.cmd += f"{type(self).__name__.lower()} {self.branch}" + self.cmd += f"{type(self).__name__.lower()}{' --rebase-merges' if self.rebase_merges else ''}{' --onto' if self.onto else ''} {self.newbase}{' ' + self.oldbase if self.onto and self.oldbase else ''}{' ' + self.until if self.onto and self.until else ''}" + + self.alt_colors = { + 0: [m.BLUE_B, m.BLUE_E], + 1: [m.PURPLE_B, m.PURPLE_E], + 2: [m.GOLD_B, m.GOLD_E], + 3: [m.TEAL_B, m.TEAL_E], + 4: [m.MAROON_B, m.MAROON_E], + 5: [m.GREEN_B, m.GREEN_E], + } def construct(self): if not settings.stdout and not settings.output_only_path and not settings.quiet: print(f"{settings.INFO_STRING} {self.cmd}") - if self.branch in self.repo.git.branch( - "--contains", self.repo.active_branch.name - ): + if self.repo.is_ancestor(self.repo.head.commit.hexsha, self.newbase): print( - "git-sim error: Branch '" - + self.repo.active_branch.name - + "' is already included in the history of active branch '" - + self.branch + "git-sim error: Current HEAD '" + + self.repo.head.commit.hexsha + + "' is already included in the history of '" + + self.newbase + "'." ) sys.exit(1) - if self.repo.active_branch.name in self.repo.git.branch( - "--contains", self.branch - ): + if self.repo.is_ancestor(self.newbase, self.repo.head.commit.hexsha): print( - "git-sim error: Branch '" - + self.branch - + "' is already based on active branch '" - + self.repo.active_branch.name + "git-sim error: New base '" + + self.newbase + + "' is already based on current HEAD '" + + self.repo.head.commit.hexsha + "'." ) sys.exit(1) self.show_intro() - branch_commit = self.get_commit(self.branch) - self.parse_commits(branch_commit) + newbase_commit = self.get_commit(self.newbase) + self.parse_commits(newbase_commit) head_commit = self.get_commit() - - reached_base = False - for commit in self.get_default_commits(): - if commit != "dark" and self.branch in self.repo.git.branch( - "--contains", commit - ): - reached_base = True + default_commits = {} + self.get_default_commits(self.get_commit(), default_commits) + flat_default_commits = self.sort_and_flatten(default_commits) self.parse_commits(head_commit, shift=4 * m.DOWN) self.parse_all() - self.center_frame_on_commit(branch_commit) + self.center_frame_on_commit(newbase_commit) - to_rebase = [] - i = 0 - current = head_commit - while self.branch not in self.repo.git.branch("--contains", current): - to_rebase.append(current) - i += 1 - if i >= self.n: - break - current = self.get_default_commits()[i] + self.to_rebase = [] + for c in flat_default_commits: + if not self.repo.is_ancestor(c, self.newbase): + if self.onto and self.until: + range_commits = list(self.repo.iter_commits(f"{self.oldbase}...{self.until}")) + if c in range_commits: + self.to_rebase.append(c) + else: + self.to_rebase.append(c) + if self.rebase_merges: + if len(self.to_rebase[-1].parents) > 1: + self.cleaned_to_rebase = [] + for j, tr in enumerate(self.to_rebase): + if self.repo.is_ancestor(self.to_rebase[-1], tr): + self.cleaned_to_rebase.append(tr) + self.to_rebase = self.cleaned_to_rebase - parent = branch_commit.hexsha + reached_base = False + merge_base = self.repo.git.merge_base(self.newbase, self.repo.head.commit.hexsha) + base_branch_commits = list(self.repo.iter_commits(f"{merge_base}...HEAD")) + for bc in base_branch_commits: + if merge_base in [p.hexsha for p in bc.parents]: + reached_base = True + if merge_base in self.drawnCommits or (self.onto and self.to_rebase[-1].hexsha in self.drawnCommits): + reached_base = True - for j, tr in enumerate(reversed(to_rebase)): + parent = newbase_commit.hexsha + branch_counts = {} + rebased_shas = [] + rebased_sha_map = {} + for j, tr in enumerate(reversed(self.to_rebase)): + if not self.rebase_merges: + if len(tr.parents) > 1: + continue if not reached_base and j == 0: message = "..." else: message = tr.message - parent = self.setup_and_draw_parent(parent, message) - self.draw_arrow_between_commits(tr.hexsha, parent) + color_index = int(self.drawnCommits[tr.hexsha].get_center()[1] / -4) - 1 + if color_index not in branch_counts: + branch_counts[color_index] = 0 + branch_counts[color_index] += 1 + commit_color = self.alt_colors[color_index % len(self.alt_colors)][1] + parent = self.setup_and_draw_parent(parent, tr.hexsha, message, color=commit_color, branch_index=color_index, default_commits=default_commits) + rebased_shas.append(parent) + rebased_sha_map[tr.hexsha] = parent self.recenter_frame() self.scale_frame() - self.reset_head_branch(parent) - self.color_by(offset=2 * len(to_rebase)) + + branch_counts = {} + k = 0 + for j, tr in enumerate(reversed(self.to_rebase)): + if not self.rebase_merges: + if len(tr.parents) > 1: + k += 1 + continue + color_index = int(self.drawnCommits[tr.hexsha].get_center()[1] / -4) - 1 + if color_index not in branch_counts: + branch_counts[color_index] = 0 + branch_counts[color_index] += 1 + commit_color = self.alt_colors[color_index % len(self.alt_colors)][1] + arrow_color = self.alt_colors[color_index % len(self.alt_colors)][1 if branch_counts[color_index] % 2 == 0 else 1] + self.draw_arrow_between_commits(tr.hexsha, rebased_shas[j - k], color=arrow_color) + + if self.onto and self.until: + until_sha = self.get_commit(self.until).hexsha + if self.until in [b.name for b in self.repo.branches]: + self.reset_head_branch(rebased_sha_map[until_sha], branch=self.until) + else: + try: + self.reset_head(rebased_sha_map[until_sha]) + except KeyError: + for sha in rebased_sha_map: + if len(self.get_commit(sha).parents) < 2: + self.reset_head(rebased_sha_map[sha]) + break + elif self.rebase_merges: + self.reset_head_branch(rebased_sha_map[default_commits[0][0].hexsha]) + else: + self.reset_head_branch(parent) + + self.color_by(offset=2 * len(self.to_rebase)) self.show_command_as_title() self.fadeout() self.show_outro() @@ -108,46 +198,88 @@ def construct(self): def setup_and_draw_parent( self, child, + orig, commitMessage="New commit", shift=numpy.array([0.0, 0.0, 0.0]), draw_arrow=True, + color=m.RED, + branch_index=0, + default_commits={}, ): circle = m.Circle( - stroke_color=m.RED, + stroke_color=color, stroke_width=self.commit_stroke_width, - fill_color=m.RED, + fill_color=color, fill_opacity=0.25, ) circle.height = 1 - circle.next_to( - self.drawnCommits[child], - m.LEFT if settings.reverse else m.RIGHT, - buff=1.5, - ) + side_offset = 0 + num_branch_index_0_to_rebase = 0 + for commit in default_commits[0]: + if commit in self.to_rebase: + num_branch_index_0_to_rebase += 1 + if self.rebase_merges: + for bi in default_commits: + if bi > 0: + if len(default_commits[bi]) >= num_branch_index_0_to_rebase: + side_offset = len(default_commits[bi]) - num_branch_index_0_to_rebase + 1 + + if self.rebase_merges: + circle.move_to( + self.drawnCommits[orig].get_center(), + ).shift(m.UP * 4 + (m.LEFT if settings.reverse else m.RIGHT) * len(default_commits[0]) * 2.5 + (m.LEFT if settings.reverse else m.RIGHT) * (5 + side_offset)) + else: + circle.next_to( + self.drawnCommits[child], + m.LEFT if settings.reverse else m.RIGHT, + buff=1.5, + ) circle.shift(shift) - start = circle.get_center() - end = self.drawnCommits[child].get_center() - arrow = m.Arrow( - start, - end, - color=self.fontColor, - stroke_width=self.arrow_stroke_width, - tip_shape=self.arrow_tip_shape, - max_stroke_width_to_length_ratio=1000, - ) - length = numpy.linalg.norm(start - end) - (1.5 if start[1] == end[1] else 3) - arrow.set_length(length) + arrow_start_ends = set() + arrows = [] + start = tuple(circle.get_center()) + if not self.rebase_merges or branch_index == 0: + end = tuple(self.drawnCommits[child].get_center()) + arrow_start_ends.add((start, end)) + if self.rebase_merges: + orig_commit = self.get_commit(orig) + if len(orig_commit.parents) < 2: + self.non_merge_reached = True + for p in orig_commit.parents: + if self.repo.is_ancestor(p, self.newbase): + continue + try: + if p not in self.to_rebase: + if branch_index > 0: + end = tuple(self.drawnCommits[self.get_commit(self.newbase).hexsha].get_center()) + elif not self.non_merge_reached: + if p.hexsha == orig_commit.parents[1].hexsha: + end = tuple(self.drawnCommits[p.hexsha].get_center()) + else: + continue + else: + end = tuple(self.drawnCommits[p.hexsha].get_center() + m.UP * 4 + (m.LEFT if settings.reverse else m.RIGHT) * len(default_commits[0]) * 2.5 + (m.LEFT if settings.reverse else m.RIGHT) * (5 + side_offset)) + arrow_start_ends.add((start, end)) + except KeyError: + pass - sha = "".join( - chr(ord(letter) + 1) - if ( - (chr(ord(letter) + 1).isalpha() and letter < "f") - or chr(ord(letter) + 1).isdigit() + for start, end in arrow_start_ends: + arrow = m.Arrow( + start, + end, + color=self.fontColor, + stroke_width=self.arrow_stroke_width, + tip_shape=self.arrow_tip_shape, + max_stroke_width_to_length_ratio=1000, ) - else letter - for letter in child[:6] - ) + length = numpy.linalg.norm(numpy.subtract(end, start)) - (1.5 if start[1] == end[1] else 3) + arrow.set_length(length) + arrows.append(arrow) + + sha = None + while not sha or sha in self.drawnCommits: + sha = self.generate_random_sha() commitId = m.Text( sha if commitMessage != "..." else "...", font=self.font, @@ -184,9 +316,39 @@ def setup_and_draw_parent( if draw_arrow: if settings.animate: - self.play(m.Create(arrow), run_time=1 / settings.speed) + for arrow in arrows: + self.play(m.Create(arrow), run_time=1 / settings.speed) + self.toFadeOut.add(arrow) else: - self.add(arrow) - self.toFadeOut.add(arrow) + for arrow in arrows: + self.add(arrow) + self.toFadeOut.add(arrow) return sha + + def get_default_commits(self, commit, default_commits, branch_index=0): + if branch_index not in default_commits: + default_commits[branch_index] = [] + if len(default_commits[branch_index]) < self.n: + if self.onto and commit.hexsha == self.get_commit(self.oldbase).hexsha: + return default_commits + if commit not in self.sort_and_flatten(default_commits) and not self.repo.is_ancestor(commit, self.newbase): + default_commits[branch_index].append(commit) + for i, parent in enumerate(commit.parents): + self.get_default_commits(parent, default_commits, branch_index + i) + return default_commits + + def draw_arrow_between_commits(self, startsha, endsha, color): + start = self.drawnCommits[startsha].get_center() + end = self.drawnCommits[endsha].get_center() + + arrow = DottedLine( + start, end, color=color, dot_kwargs={"color": color} + ).add_tip() + length = numpy.linalg.norm(start - end) - 1.65 + arrow.set_length(length) + self.draw_arrow(True, arrow) + + def sort_and_flatten(self, d): + sorted_values = [d[key] for key in sorted(d.keys(), reverse=True)] + return sum(sorted_values, []) diff --git a/src/git_sim/settings.py b/src/git_sim/settings.py index 752b56b..2fbbd64 100644 --- a/src/git_sim/settings.py +++ b/src/git_sim/settings.py @@ -20,8 +20,8 @@ class Settings(BaseSettings): transparent_bg: bool = False logo: pathlib.Path = pathlib.Path(__file__).parent.resolve() / "logo.png" low_quality: bool = False - max_branches_per_commit: int = 1 - max_tags_per_commit: int = 1 + max_branches_per_commit: int = 2 + max_tags_per_commit: int = 2 media_dir: pathlib.Path = pathlib.Path().cwd() outro_bottom_text: str = "Learn more at initialcommit.com" outro_top_text: str = "Thanks for using Initial Commit!" diff --git a/src/git_sim/tag.py b/src/git_sim/tag.py index fa9eee3..4eeda9c 100644 --- a/src/git_sim/tag.py +++ b/src/git_sim/tag.py @@ -43,9 +43,9 @@ def construct(self): print(f"{settings.INFO_STRING} {self.cmd}") self.show_intro() - self.parse_commits() + self.parse_commits(self.get_commit(sha_or_ref=self.name) if self.d else None) self.parse_all() - self.center_frame_on_commit(self.get_commit()) + self.center_frame_on_commit(self.get_commit(sha_or_ref=self.name) if self.d else self.get_commit()) if not self.d: tagText = m.Text(