diff --git a/eks_automation/app.py b/eks_automation/app.py index c692b4d..6da7e8e 100644 --- a/eks_automation/app.py +++ b/eks_automation/app.py @@ -484,12 +484,12 @@ def commit_repository_contents(self, repo_name, work_dir, commit_message, branch "sha": blob_sha }) - # Try to get the latest commit SHA for the branch - # If it doesn't exist, we'll create it + # Try to get the latest commit SHA from the base branch + base_branch = "main" # Always use main as base when creating new branches try: - latest_commit_sha = self.get_reference_sha(repo_name, f"heads/{target_branch}") - latest_commit = self.get_commit(repo_name, latest_commit_sha) - base_tree_sha = latest_commit["tree"]["sha"] + base_commit_sha = self.get_reference_sha(repo_name, f"heads/{base_branch}") + base_commit = self.get_commit(repo_name, base_commit_sha) + base_tree_sha = base_commit["tree"]["sha"] except Exception: # If we can't get the reference, assume it's a new repo with no commits base_tree_sha = None @@ -504,7 +504,7 @@ def commit_repository_contents(self, repo_name, work_dir, commit_message, branch repo_name, commit_message, new_tree_sha, - [latest_commit_sha] + [base_commit_sha] ) else: # If it's a new repo, create the first commit @@ -517,18 +517,25 @@ def commit_repository_contents(self, repo_name, work_dir, commit_message, branch # Update or create the reference to point to the new commit try: + # Try to update existing branch self.update_reference( repo_name, f"heads/{target_branch}", new_commit_sha ) except Exception: - # If the reference doesn't exist, create it - self.create_reference( - repo_name, - f"refs/heads/{target_branch}", - new_commit_sha - ) + # If the branch doesn't exist, create it + try: + self.create_reference( + repo_name, + f"refs/heads/{target_branch}", + new_commit_sha + ) + except Exception as e: + # If we still can't create the branch, something is wrong + error_message = f"Failed to create or update branch {target_branch} for {repo_name}: {str(e)}" + logger.error(error_message) + raise Exception(error_message) return target_branch diff --git a/eks_automation/tests/test_github_client_integration.py b/eks_automation/tests/test_github_client_integration.py index afe0587..7b24c33 100644 --- a/eks_automation/tests/test_github_client_integration.py +++ b/eks_automation/tests/test_github_client_integration.py @@ -126,6 +126,7 @@ def test_branch_operations(self, temp_repo_name, cleanup_repo): # Create a test file in main branch with tempfile.TemporaryDirectory() as work_dir: + # Initial commit on main branch main_file = os.path.join(work_dir, "test.txt") with open(main_file, "w") as f: f.write("main branch content") @@ -136,10 +137,15 @@ def test_branch_operations(self, temp_repo_name, cleanup_repo): "Initial commit on main" ) - # Create and switch to a new branch + # Create and switch to a test branch test_branch = "test-branch" - shutil.rmtree(work_dir) - os.makedirs(work_dir) + # Clean directory for test branch changes + for file in os.listdir(work_dir): + file_path = os.path.join(work_dir, file) + if os.path.isfile(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) # Create different content in test branch with open(main_file, "w") as f: @@ -152,14 +158,18 @@ def test_branch_operations(self, temp_repo_name, cleanup_repo): branch=test_branch ) - # Verify main branch content - main_output = os.path.join(work_dir, "main") + # Clone and verify main branch content + main_output = os.path.join(work_dir, "clone-main") + os.makedirs(main_output, exist_ok=True) self.client.clone_repository_contents(repo_name, main_output, branch="main") + with open(os.path.join(main_output, "test.txt")) as f: - assert f.read() == "main branch content" + assert f.read().strip() == "main branch content" - # Verify test branch content - test_output = os.path.join(work_dir, "test") + # Clone and verify test branch content + test_output = os.path.join(work_dir, "clone-test") + os.makedirs(test_output, exist_ok=True) self.client.clone_repository_contents(repo_name, test_output, branch=test_branch) + with open(os.path.join(test_output, "test.txt")) as f: - assert f.read() == "test branch content" + assert f.read().strip() == "test branch content"