Skip to content

Commit 851bad6

Browse files
Refactor rebase subcommand to use more general <newbase> and <oldbase> terminology, update commit path tracings algos for accuracy
Signed-off-by: Jacob Stopak <[email protected]>
1 parent a5542d3 commit 851bad6

File tree

3 files changed

+115
-66
lines changed

3 files changed

+115
-66
lines changed

‎src/git_sim/commands.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,9 @@ def push(
259259

260260

261261
def rebase(
262-
branch: str = typer.Argument(
262+
newbase: str = typer.Argument(
263263
...,
264-
help="The branch to simulate rebasing the checked-out commit onto",
264+
help="The new base commit to simulate rebasing to",
265265
),
266266
rebase_merges: bool = typer.Option(
267267
False,
@@ -274,7 +274,7 @@ def rebase(
274274
"--onto",
275275
help="Rebase onto given branch instead of upstream",
276276
),
277-
oldparent: str = typer.Argument(
277+
oldbase: str = typer.Argument(
278278
None,
279279
help="The parent of the commit to rebase (to be used with --onto)",
280280
),
@@ -285,7 +285,7 @@ def rebase(
285285
):
286286
from git_sim.rebase import Rebase
287287

288-
scene = Rebase(branch=branch, rebase_merges=rebase_merges, onto=onto, oldparent=oldparent, until=until)
288+
scene = Rebase(newbase=newbase, rebase_merges=rebase_merges, onto=onto, oldbase=oldbase, until=until)
289289
handle_animations(scene=scene)
290290

291291

‎src/git_sim/git_sim_base_command.py

+39-10
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from git.repo import Repo
1414
from git.exc import GitCommandError, InvalidGitRepositoryError, BadName
1515

16+
from collections import deque
17+
1618
from git_sim.settings import settings
1719
from git_sim.enums import ColorByOptions, StyleOptions
1820

@@ -1390,16 +1392,43 @@ def generate_random_sha(self):
13901392
valid_chars = "0123456789abcdef"
13911393
return "".join(random.choices(valid_chars, k=6))
13921394

1393-
def get_mainline_distance(self, sha_or_ref1, sha_or_ref2):
1394-
commit1 = self.get_commit(sha_or_ref1)
1395-
commit2 = self.get_commit(sha_or_ref2)
1396-
if not self.repo.is_ancestor(commit1, commit2):
1397-
print(f"git-sim error: specified sha/ref '{sha_or_ref1}' must be an ancestor of sha/ref '{sha_or_ref2}'.")
1398-
sys.exit(1)
1399-
d = 0
1400-
while self.get_commit(f"{commit2.hexsha}~{d}").hexsha != commit1.hexsha:
1401-
d += 1
1402-
return d
1395+
def get_shortest_distance(self, sha_or_ref1, sha_or_ref2):
1396+
# Create a queue for BFS that stores (commit, depth) tuples
1397+
queue = deque([(self.repo.commit(sha_or_ref2), 0)])
1398+
visited = set()
1399+
1400+
# Perform BFS from the start commit
1401+
while queue:
1402+
current_commit, depth = queue.popleft()
1403+
1404+
# If we reach the end commit
1405+
if current_commit.hexsha == self.repo.commit(sha_or_ref1).hexsha:
1406+
print(depth)
1407+
return depth
1408+
1409+
# Mark this commit as visited
1410+
visited.add(current_commit.hexsha)
1411+
1412+
# Queue all unvisited parents
1413+
for parent in current_commit.parents:
1414+
if parent.hexsha not in visited:
1415+
queue.append((parent, depth + 1))
1416+
1417+
# If no path found
1418+
return -1
1419+
1420+
def is_on_mainline(self, sha_or_ref1, sha_or_ref2):
1421+
current_commit = self.get_commit(sha_or_ref2)
1422+
1423+
# Traverse the first parent history
1424+
while current_commit:
1425+
if current_commit.hexsha == self.get_commit(sha_or_ref1).hexsha:
1426+
return True
1427+
if current_commit.parents:
1428+
current_commit = current_commit.parents[0]
1429+
else:
1430+
break
1431+
return False
14031432

14041433

14051434
class DottedLine(m.Line):

‎src/git_sim/rebase.py

+72-52
Original file line numberDiff line numberDiff line change
@@ -11,50 +11,57 @@
1111

1212

1313
class Rebase(GitSimBaseCommand):
14-
def __init__(self, branch: str, rebase_merges: bool, onto: bool, oldparent: str, until: str):
14+
def __init__(self, newbase: str, rebase_merges: bool, onto: bool, oldbase: str, until: str):
1515
super().__init__()
16-
self.branch = branch
16+
self.newbase = newbase
1717
self.rebase_merges = rebase_merges
1818
self.onto = onto
19-
self.oldparent = oldparent
19+
self.oldbase = oldbase
2020
self.until = until
2121

22+
self.non_merge_reached = False
23+
2224
try:
23-
git.repo.fun.rev_parse(self.repo, self.branch)
25+
git.repo.fun.rev_parse(self.repo, self.newbase)
2426
except git.exc.BadName:
2527
print(
2628
"git-sim error: '"
27-
+ self.branch
29+
+ self.newbase
2830
+ "' is not a valid Git ref or identifier."
2931
)
3032
sys.exit(1)
3133

3234
if self.onto:
33-
if not self.oldparent:
35+
if not self.oldbase:
36+
print(
37+
"git-sim error: Please specify <oldbase> as the parent of the commit to rebase"
38+
)
39+
sys.exit(1)
40+
elif not self.is_on_mainline(self.oldbase, "HEAD"):
3441
print(
35-
"git-sim error: Please specify the parent of the commit to rebase ('oldparent')"
42+
"git-sim error: Currently only mainline commit paths (i.e. paths traced following _first parents_) along <oldbase> to HEAD are supported for rebase simulations"
3643
)
3744
sys.exit(1)
38-
self.n = max(self.get_mainline_distance(self.oldparent, "HEAD"), self.n)
45+
self.n = max(self.get_shortest_distance(self.oldbase, "HEAD"), self.n)
3946

4047
if self.until:
41-
self.until_n = self.get_mainline_distance(self.oldparent, self.until)
48+
self.until_n = self.get_shortest_distance(self.oldbase, self.until)
4249
else:
43-
if self.oldparent or self.until:
50+
if self.oldbase or self.until:
4451
print(
45-
"git-sim error: Please use --onto flag when specifying <oldparent> and <until>"
52+
"git-sim error: Please use --onto flag when specifying <oldbase> and <until>"
4653
)
4754
sys.exit(1)
4855

49-
if self.branch in [branch.name for branch in self.repo.heads]:
50-
self.selected_branches.append(self.branch)
56+
if self.newbase in [branch.name for branch in self.repo.heads]:
57+
self.selected_branches.append(self.newbase)
5158

5259
try:
5360
self.selected_branches.append(self.repo.active_branch.name)
5461
except TypeError:
5562
pass
5663

57-
self.cmd += f"{type(self).__name__.lower()}{' --rebase-merges' if self.rebase_merges else ''}{' --onto' if self.onto else ''} {self.branch}{' ' + self.oldparent if self.onto and self.oldparent else ''}{' ' + self.until if self.onto and self.until else ''}"
64+
self.cmd += f"{type(self).__name__.lower()}{' --rebase-merges' if self.rebase_merges else ''}{' --onto' if self.onto else ''} {self.newbase}{' ' + self.oldbase if self.onto and self.oldbase else ''}{' ' + self.until if self.onto and self.until else ''}"
5865

5966
self.alt_colors = {
6067
0: [m.BLUE_B, m.BLUE_E],
@@ -69,58 +76,61 @@ def construct(self):
6976
if not settings.stdout and not settings.output_only_path and not settings.quiet:
7077
print(f"{settings.INFO_STRING} {self.cmd}")
7178

72-
if self.branch in self.repo.git.branch(
73-
"--contains", self.repo.active_branch.name
74-
):
79+
if self.repo.is_ancestor(self.repo.head.commit.hexsha, self.newbase):
7580
print(
76-
"git-sim error: Branch '"
77-
+ self.repo.active_branch.name
78-
+ "' is already included in the history of active branch '"
79-
+ self.branch
81+
"git-sim error: Current HEAD '"
82+
+ self.repo.head.commit.hexsha
83+
+ "' is already included in the history of '"
84+
+ self.newbase
8085
+ "'."
8186
)
8287
sys.exit(1)
8388

84-
if self.repo.active_branch.name in self.repo.git.branch(
85-
"--contains", self.branch
86-
):
89+
if self.repo.is_ancestor(self.newbase, self.repo.head.commit.hexsha):
8790
print(
88-
"git-sim error: Branch '"
89-
+ self.branch
90-
+ "' is already based on active branch '"
91-
+ self.repo.active_branch.name
91+
"git-sim error: New base '"
92+
+ self.newbase
93+
+ "' is already based on current HEAD '"
94+
+ self.repo.head.commit.hexsha
9295
+ "'."
9396
)
9497
sys.exit(1)
9598

9699
self.show_intro()
97-
branch_commit = self.get_commit(self.branch)
98-
self.parse_commits(branch_commit)
100+
newbase_commit = self.get_commit(self.newbase)
101+
self.parse_commits(newbase_commit)
99102
head_commit = self.get_commit()
100103
default_commits = {}
101104
self.get_default_commits(self.get_commit(), default_commits)
102105
flat_default_commits = self.sort_and_flatten(default_commits)
103106

104107
self.parse_commits(head_commit, shift=4 * m.DOWN)
105108
self.parse_all()
106-
self.center_frame_on_commit(branch_commit)
109+
self.center_frame_on_commit(newbase_commit)
107110

108111
self.to_rebase = []
109112
for c in flat_default_commits:
110-
if self.branch not in self.repo.git.branch("--contains", c):
113+
if not self.repo.is_ancestor(c, self.newbase):
111114
if self.onto and self.until:
112-
range_commits = list(self.repo.iter_commits(f"{self.oldparent}...{self.until}"))
115+
range_commits = list(self.repo.iter_commits(f"{self.oldbase}...{self.until}"))
113116
if c in range_commits:
114117
self.to_rebase.append(c)
115118
else:
116119
self.to_rebase.append(c)
120+
if self.rebase_merges:
121+
if len(self.to_rebase[-1].parents) > 1:
122+
self.cleaned_to_rebase = []
123+
for j, tr in enumerate(self.to_rebase):
124+
if self.repo.is_ancestor(self.to_rebase[-1], tr):
125+
self.cleaned_to_rebase.append(tr)
126+
self.to_rebase = self.cleaned_to_rebase
117127

118128
reached_base = False
119-
merge_base = self.repo.git.merge_base(self.branch, self.repo.active_branch.name)
129+
merge_base = self.repo.git.merge_base(self.newbase, self.repo.head.commit.hexsha)
120130
if merge_base in self.drawnCommits or (self.onto and self.to_rebase[-1].hexsha in self.drawnCommits):
121131
reached_base = True
122132

123-
parent = branch_commit.hexsha
133+
parent = newbase_commit.hexsha
124134
branch_counts = {}
125135
rebased_shas = []
126136
rebased_sha_map = {}
@@ -159,15 +169,20 @@ def construct(self):
159169
arrow_color = self.alt_colors[color_index % len(self.alt_colors)][1 if branch_counts[color_index] % 2 == 0 else 1]
160170
self.draw_arrow_between_commits(tr.hexsha, rebased_shas[j - k], color=arrow_color)
161171

162-
if self.rebase_merges:
163-
if self.onto and self.until:
164-
until_sha = self.get_commit(self.until).hexsha
165-
if until_sha == self.repo.head.commit.hexsha:
166-
self.reset_head_branch(rebased_sha_map[until_sha])
167-
else:
168-
self.reset_head(rebased_sha_map[until_sha])
172+
if self.onto and self.until:
173+
until_sha = self.get_commit(self.until).hexsha
174+
if until_sha == self.repo.head.commit.hexsha:
175+
self.reset_head_branch(rebased_sha_map[until_sha])
169176
else:
170-
self.reset_head_branch(rebased_sha_map[default_commits[0][0].hexsha])
177+
try:
178+
self.reset_head(rebased_sha_map[until_sha])
179+
except KeyError:
180+
for sha in rebased_sha_map:
181+
if len(self.get_commit(sha).parents) < 2:
182+
self.reset_head(rebased_sha_map[sha])
183+
break
184+
elif self.rebase_merges:
185+
self.reset_head_branch(rebased_sha_map[default_commits[0][0].hexsha])
171186
else:
172187
self.reset_head_branch(parent)
173188
self.color_by(offset=2 * len(self.to_rebase))
@@ -223,14 +238,21 @@ def setup_and_draw_parent(
223238
end = tuple(self.drawnCommits[child].get_center())
224239
arrow_start_ends.add((start, end))
225240
if self.rebase_merges:
226-
for p in self.get_commit(orig).parents:
227-
if self.branch in self.repo.git.branch(
228-
"--contains", p
229-
):
241+
orig_commit = self.get_commit(orig)
242+
if len(orig_commit.parents) < 2:
243+
self.non_merge_reached = True
244+
for p in orig_commit.parents:
245+
if self.repo.is_ancestor(p, self.newbase):
230246
continue
231247
try:
232248
if p not in self.to_rebase:
233-
end = tuple(self.drawnCommits[self.get_commit(self.branch).hexsha].get_center())
249+
if branch_index > 0:
250+
end = tuple(self.drawnCommits[self.get_commit(self.newbase).hexsha].get_center())
251+
elif not self.non_merge_reached:
252+
if p.hexsha == orig_commit.parents[1].hexsha:
253+
end = tuple(self.drawnCommits[p.hexsha].get_center())
254+
else:
255+
continue
234256
else:
235257
end = tuple(self.drawnCommits[p.hexsha].get_center() + m.UP * 4 + (m.LEFT if settings.reverse else m.RIGHT) * len(default_commits[0]) * 2.5 + (m.LEFT * side_offset if settings.reverse else m.RIGHT * side_offset) * 5)
236258
arrow_start_ends.add((start, end))
@@ -303,11 +325,9 @@ def get_default_commits(self, commit, default_commits, branch_index=0):
303325
if branch_index not in default_commits:
304326
default_commits[branch_index] = []
305327
if len(default_commits[branch_index]) < self.n:
306-
if self.onto and commit.hexsha == self.get_commit(self.oldparent).hexsha:
328+
if self.onto and commit.hexsha == self.get_commit(self.oldbase).hexsha:
307329
return default_commits
308-
if commit not in self.sort_and_flatten(default_commits) and self.branch not in self.repo.git.branch(
309-
"--contains", commit
310-
):
330+
if commit not in self.sort_and_flatten(default_commits) and not self.repo.is_ancestor(commit, self.newbase):
311331
default_commits[branch_index].append(commit)
312332
for i, parent in enumerate(commit.parents):
313333
self.get_default_commits(parent, default_commits, branch_index + i)

0 commit comments

Comments
 (0)