@@ -66,15 +66,75 @@ def _depth_default(self):
6666 where the GitPuller class hadn't been loaded already."""
6767 return int (os .environ .get ('NBGITPULLER_DEPTH' , 1 ))
6868
69- def __init__ (self , git_url , branch_name , repo_dir , ** kwargs ):
70- assert git_url and branch_name
69+ def __init__ (self , git_url , repo_dir , ** kwargs ):
70+ assert git_url
7171
7272 self .git_url = git_url
73- self .branch_name = branch_name
73+ self .branch_name = kwargs .pop ("branch" )
74+
75+ if self .branch_name is None :
76+ self .branch_name = self .resolve_default_branch ()
77+ elif not self .branch_exists (self .branch_name ):
78+ raise ValueError (f"Branch: { self .branch_name } -- not found in repo: { self .git_url } " )
79+
7480 self .repo_dir = repo_dir
7581 newargs = {k : v for k , v in kwargs .items () if v is not None }
7682 super (GitPuller , self ).__init__ (** newargs )
7783
84+ def branch_exists (self , branch ):
85+ """
86+ This checks to make sure the branch we are told to access
87+ exists in the repo
88+ """
89+ try :
90+ heads = subprocess .run (
91+ ["git" , "ls-remote" , "--heads" , self .git_url ],
92+ capture_output = True ,
93+ text = True ,
94+ check = True
95+ )
96+ tags = subprocess .run (
97+ ["git" , "ls-remote" , "--tags" , self .git_url ],
98+ capture_output = True ,
99+ text = True ,
100+ check = True
101+ )
102+ lines = heads .stdout .splitlines () + tags .stdout .splitlines ()
103+ branches = []
104+ for line in lines :
105+ _ , ref = line .split ()
106+ refs , heads , branch_name = ref .split ("/" , 2 )
107+ branches .append (branch_name )
108+ return branch in branches
109+ except subprocess .CalledProcessError :
110+ m = f"Problem accessing list of branches and/or tags: { self .git_url } "
111+ logging .exception (m )
112+ raise ValueError (m )
113+
114+ def resolve_default_branch (self ):
115+ """
116+ This will resolve the default branch of the repo in
117+ the case where the branch given does not exist
118+ """
119+ try :
120+ head_branch = subprocess .run (
121+ ["git" , "ls-remote" , "--symref" , self .git_url , "HEAD" ],
122+ capture_output = True ,
123+ text = True ,
124+ check = True
125+ )
126+ for line in head_branch .stdout .splitlines ():
127+ if line .startswith ("ref:" ):
128+ # line resembles --> ref: refs/heads/main HEAD
129+ _ , ref , head = line .split ()
130+ refs , heads , branch_name = ref .split ("/" , 2 )
131+ return branch_name
132+ raise ValueError (f"default branch not found in { self .git_url } " )
133+ except subprocess .CalledProcessError :
134+ m = f"Problem accessing HEAD branch: { self .git_url } "
135+ logging .exception (m )
136+ raise ValueError (m )
137+
78138 def pull (self ):
79139 """
80140 Pull selected repo from a remote git repository,
@@ -243,13 +303,11 @@ def main():
243303
244304 parser = argparse .ArgumentParser (description = 'Synchronizes a github repository with a local repository.' )
245305 parser .add_argument ('git_url' , help = 'Url of the repo to sync' )
246- parser .add_argument ('branch_name' , default = 'master' , help = 'Branch of repo to sync' , nargs = '?' )
247306 parser .add_argument ('repo_dir' , default = '.' , help = 'Path to clone repo under' , nargs = '?' )
248307 args = parser .parse_args ()
249308
250309 for line in GitPuller (
251310 args .git_url ,
252- args .branch_name ,
253311 args .repo_dir
254312 ).pull ():
255313 print (line )
0 commit comments