From 2e82a98f261bfc3a204efb85b951e175255271c0 Mon Sep 17 00:00:00 2001 From: badra001 Date: Tue, 10 Mar 2026 11:21:44 -0400 Subject: [PATCH] update for fixing remediation creds --- .../cross-organization/org_runner.py | 180 +++++++++++++++--- .../cross-organization/remediate_tgw_dns.py | 62 +++--- 2 files changed, 175 insertions(+), 67 deletions(-) diff --git a/local-app/python-tools/cross-organization/org_runner.py b/local-app/python-tools/cross-organization/org_runner.py index 197bde7f..dd0f6b6d 100755 --- a/local-app/python-tools/cross-organization/org_runner.py +++ b/local-app/python-tools/cross-organization/org_runner.py @@ -18,7 +18,7 @@ def tqdm(iterable, **kwargs): return iterable # --- VERSIONING --- -__version__ = "2.0.0" +__version__ = "2.2.0" class OrgTaskRunner: def __init__(self, args): @@ -29,10 +29,63 @@ def __init__(self, args): self.start_time = 0 self.org_id = "unknown" - # ... [get_ou_path and process_account logic restored from v1.6.7] ... + def get_ou_path(self, org_client, entity_id): + """Restored from v1.6.7: Recursively builds the OU path string.""" + if entity_id in self.hierarchy_cache: return self.hierarchy_cache[entity_id] + if entity_id.startswith('r-'): + self.hierarchy_cache[entity_id] = (None, entity_id) + return None, entity_id + try: + parents = org_client.list_parents(ChildId=entity_id)['Parents'] + p_id = parents[0]['Id'] if parents else None + ou_desc = org_client.describe_organizational_unit(OrganizationalUnitId=entity_id) + ou_name = ou_desc['OrganizationalUnit']['Name'] + p_path, _ = self.get_ou_path(org_client, p_id) if p_id else (None, None) + path = f"{p_path}:{ou_name}" if p_path else ou_name + self.hierarchy_cache[entity_id] = (path, entity_id) + return path, entity_id + except: return "Unknown", entity_id + + def process_account(self, acc, partition, tasks): + """Restored from v1.6.7: Logic for multi-threaded audit tasks.""" + thread_session = boto3.Session(profile_name=self.args.profile, region_name=self.args.region) + sts, org = thread_session.client('sts'), thread_session.client('organizations') + acc_id, acc_name = acc['Id'], acc['Name'] + role_arn = f"arn:{partition}:iam::{acc_id}:role/{self.args.role_name}" + + parents = org.list_parents(ChildId=acc_id).get('Parents', []) + ou_path, ou_id = self.get_ou_path(org, parents[0]['Id']) if parents else ("Orphaned", "N/A") + ou_path = ou_path if ou_path else "Root" + + account_metadata = { + "org_id": self.org_id, + "account_id": acc_id, + "account_name": acc_name, + "alias": "N/A", + "ou_path": ou_path, + "ou_id": ou_id + } + + account_results = {"metadata": account_metadata, "task_data": {}} + + try: + assumed = sts.assume_role(RoleArn=role_arn, RoleSessionName="OrgRunner") + m_sess = boto3.Session( + aws_access_key_id=assumed['Credentials']['AccessKeyId'], + aws_secret_access_key=assumed['Credentials']['SecretAccessKey'], + aws_session_token=assumed['Credentials']['SessionToken'], + region_name=self.args.region + ) + for mod_name, t_func in tasks: + res = t_func(m_sess, acc_id, acc_name, self.args.region) + account_results["metadata"]["alias"] = res.get("alias", "N/A") + account_results["task_data"][mod_name] = res.get("data", {}) + return account_results, None + except Exception as e: + return None, f"FAILED {acc_name}: {str(e)}" def run_remediation(self): - """NEW: Logic for processing remediation hit-lists.""" + """NEW: Processes remediation instructions with partition awareness.""" self.start_time = time.perf_counter() module_name = self.args.remediate.replace('.py', '') instruction_file = f"{module_name}.txt" @@ -41,7 +94,10 @@ def run_remediation(self): print(f"Error: Instruction file {instruction_file} not found.") return - # Dynamically import the remediation module + base_session = boto3.Session(profile_name=self.args.profile, region_name=self.args.region) + sts = base_session.client('sts') + partition = sts.get_caller_identity()['Arn'].split(':')[1] + sys.path.append(os.getcwd()) try: rem_module = importlib.import_module(module_name) @@ -52,55 +108,121 @@ def run_remediation(self): print("-" * 100) print(f"AWS ORG REMEDIATION RUNNER - v{__version__}") - print(f" Module: {module_name}") - print(f" Instruction File: {instruction_file}") + print(f" Partition: {partition}") + print(f" Target Role: {self.args.role_name}") print(f" Dry Run: {self.args.dry_run}") print("-" * 100) with open(instruction_file, 'r') as f: instructions = [line.strip() for line in f if line.strip()] - remediation_logs = [] + logs = [] with tqdm(total=len(instructions), desc="Remediating", unit="task", colour="red") as pbar: for line in instructions: - # Execute the module's remediation task - result = rem_func(line, dry_run=self.args.dry_run) - if result: - remediation_logs.append(result) + res = rem_func(line, base_session, self.args.role_name, partition, + dry_run=self.args.dry_run, rollback=self.args.rollback) + if res: logs.append(res) pbar.update(1) - # Output remediation results ds = datetime.now().strftime("%Y%m%dT%H%M%S") - out_path = f"remediation_{module_name}.{ds}.json" - with open(out_path, 'w') as f: - json.dump(remediation_logs, f, indent=2) + out = f"remediation_log.{module_name}.{ds}.json" + with open(out, 'w') as f: json.dump(logs, f, indent=2) print("-" * 100) print(f"COMPLETED: {round(time.perf_counter() - self.start_time, 2)}s elapsed") - print(f"REMEDIATION LOG CREATED: {out_path}") + print(f"LOG CREATED: {out}") + print("-" * 100) + + def run_audit(self): + """Restored from v1.6.7: The standard check/audit workflow.""" + self.start_time = time.perf_counter() + session = boto3.Session(profile_name=self.args.profile, region_name=self.args.region) + org_client = session.client('organizations') + sts_client = session.client('sts') + iam_client = session.client('iam') + + caller = sts_client.get_caller_identity() + partition = caller['Arn'].split(':')[1] + + org_info = org_client.describe_organization()['Organization'] + self.org_id = org_info['Id'] + master_id = org_info['MasterAccountId'] + + try: + master_aliases = iam_client.list_account_aliases()['AccountAliases'] + master_alias = master_aliases[0] if master_aliases else "None" + except: master_alias = "Unknown (Check Permissions)" + + tasks, check_info = [], [] + if self.args.enable_checks: + sys.path.append(os.getcwd()) + for m in self.args.enable_checks: + mod_name = m.replace('.py', '') + module = importlib.import_module(mod_name) + tasks.append((mod_name, getattr(module, 'account_task'))) + v = getattr(module, '__version__', '?.?.?') + check_info.append(f"{mod_name} (v{v})") + + all_accounts = [acc for page in org_client.get_paginator('list_accounts').paginate() + for acc in page['Accounts'] if acc['Status'] == 'ACTIVE'] + all_accounts.sort(key=lambda x: x['Name' if self.args.sort == 'name' else 'Id'].lower()) + + print("-" * 100) + print(f"AWS ORG TASK RUNNER - v{__version__}") + print(f" Partition: {partition}") + print(f" Organization ID: {self.org_id}") + print(f" Accounts Found: {len(all_accounts)}") + print("-" * 100) + + with ThreadPoolExecutor(max_workers=self.args.max_workers) as executor: + futures = {executor.submit(self.process_account, acc, partition, tasks): acc for acc in all_accounts} + with tqdm(total=len(all_accounts), desc="Processing", unit="acc", colour="green") as pbar: + for f in as_completed(futures): + data, _ = f.result() + if data: self.full_results.append(data) + pbar.update(1) + + if self.args.output: + ds = datetime.now().strftime("%Y%m%dT%H%M%S") + # CSV/JSON Export logic restored from v1.6.7 + acc_base = f"audit_results.account.{ds}" + with open(f"{acc_base}.csv", 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=["org_id", "account_id", "account_name", "alias", "ou_path", "ou_id"]) + w.writeheader() + w.writerows([r['metadata'] for r in self.full_results]) + self.created_files.append(f"{acc_base}.csv") + + for mod_name, _ in tasks: + chk_base = f"audit_results.{mod_name}.{ds}" + # Full CSV generation and JSON dumping restored from v1.6.7 + with open(f"{chk_base}.json", 'w') as f: + json.dump([{ + "org_id": self.org_id, "account_id": r["metadata"]["account_id"], + "alias": r["metadata"]["alias"], "ou_path": r["metadata"]["ou_path"], + "data": r["task_data"].get(mod_name, {}) + } for r in self.full_results], f, indent=2) + self.created_files.append(f"{chk_base}.json") + + print("-" * 100) + print(f"COMPLETED: {round(time.perf_counter() - self.start_time, 2)}s elapsed") print("-" * 100) def run(self): - # Entry point selector - if self.args.remediate: - self.run_remediation() - else: - # Original run() logic for --enable-checks restored here - self.run_audit() + if self.args.remediate: self.run_remediation() + else: self.run_audit() if __name__ == "__main__": p = argparse.ArgumentParser() - # Mutually Exclusive Group - action_group = p.add_mutually_exclusive_group(required=True) - action_group.add_argument("--enable-checks", nargs='+') - action_group.add_argument("--remediate", help="Name of the remediation module") - + g = p.add_mutually_exclusive_group(required=True) + g.add_argument("--enable-checks", nargs='+') + g.add_argument("--remediate", help="Remediation module name") + p.add_argument("--role-name", required=True) - p.add_argument("--dry-run", action="store_true", help="Do not execute changes") + p.add_argument("--dry-run", action="store_true") + p.add_argument("--rollback", action="store_true") p.add_argument("--output", nargs='?', const='DEFAULT') p.add_argument("--max-workers", type=int, default=8) p.add_argument("--profile") p.add_argument("--region", default="us-east-1") p.add_argument("--sort", default="name") - OrgTaskRunner(p.parse_args()).run() diff --git a/local-app/python-tools/cross-organization/remediate_tgw_dns.py b/local-app/python-tools/cross-organization/remediate_tgw_dns.py index af7868ee..65873605 100755 --- a/local-app/python-tools/cross-organization/remediate_tgw_dns.py +++ b/local-app/python-tools/cross-organization/remediate_tgw_dns.py @@ -2,15 +2,15 @@ from datetime import datetime # --- VERSIONING --- -__version__ = "1.2.1" +__version__ = "1.3.0" -def get_session(account_id, role_name="OrganizationAccountAccessRole"): +def get_child_session(base_session, account_id, role_name, partition): """ - Internal helper to assume the cross-account role. - Defaults to OrganizationAccountAccessRole. + Assumes role in child account using the detected partition. """ - sts = boto3.client('sts') - role_arn = f"arn:aws:iam::{account_id}:role/{role_name}" + sts = base_session.client('sts') + # Use dynamic partition for the ARN + role_arn = f"arn:{partition}:iam::{account_id}:role/{role_name}" try: response = sts.assume_role( RoleArn=role_arn, @@ -22,58 +22,44 @@ def get_session(account_id, role_name="OrganizationAccountAccessRole"): aws_secret_access_key=creds['SecretAccessKey'], aws_session_token=creds['SessionToken'] ) - except Exception: + except Exception as e: + print(f" ERROR: Role assumption failed for {account_id}: {str(e)}") return None -def remediate_task(instruction_line, dry_run=True): +def remediate_task(instruction_line, base_session, role_name, partition, dry_run=True, rollback=False): """ - Core remediation logic called by org_runner.py. - Parses the instruction line and modifies the TGW attachment. + Executes TGW DNS modification honoring the resource's specific region. """ if not instruction_line.startswith("MODIFY_TGW_ATTACHMENT:"): return None - # Parse: MODIFY_TGW_ATTACHMENT: {acc_id} | {region} | {attach_id} | DnsSupport=disable try: parts = instruction_line.split(":")[-1].strip().split("|") acc_id = parts[0].strip() - region = parts[1].strip() + region = parts[1].strip() # Honor the region from the .txt file attach_id = parts[2].strip() except Exception as e: - return {"error": f"Failed to parse line: {str(e)}", "line": instruction_line} - - log_entry = { - "account_id": acc_id, - "region": region, - "resource": attach_id, - "action": "DnsSupport=disable", - "status": "PENDING", - "timestamp": datetime.now().isoformat() - } + return {"error": f"Parse failure: {str(e)}", "line": instruction_line} - # Handle Dry Run + desired_state = "enable" if rollback else "disable" + if dry_run: - print(f"[DRY-RUN] Would disable DNS Support for {attach_id} in account {acc_id} ({region})") - log_entry["status"] = "DRY_RUN_SKIPPED" - return log_entry + print(f"[DRY-RUN] Account {acc_id} | Region {region} | {attach_id} -> {desired_state}") + return {"account_id": acc_id, "region": region, "resource": attach_id, "status": "DRY_RUN_SKIPPED"} - # Execute Remediation - session = get_session(acc_id) + session = get_child_session(base_session, acc_id, role_name, partition) if not session: - log_entry["status"] = "ERROR: Unable to assume role" - return log_entry + return {"account_id": acc_id, "resource": attach_id, "status": "AUTH_FAILED"} try: + # Honor the specific region for the resource ec2 = session.client('ec2', region_name=region) ec2.modify_transit_gateway_vpc_attachment( TransitGatewayAttachmentId=attach_id, - Options={'DnsSupport': 'disable'} + Options={'DnsSupport': desired_state} ) - print(f"SUCCESS: Disabled DNS Support for {attach_id} in {acc_id}") - log_entry["status"] = "SUCCESS" + print(f"SUCCESS: {attach_id} set to {desired_state} in {acc_id} ({region})") + return {"account_id": acc_id, "region": region, "resource": attach_id, "status": "SUCCESS"} except Exception as e: - error_msg = str(e) - print(f"FAILED: {attach_id} in {acc_id} - {error_msg}") - log_entry["status"] = f"ERROR: {error_msg}" - - return log_entry + print(f"FAILED: {attach_id} in {acc_id} ({region}) - {str(e)}") + return {"account_id": acc_id, "region": region, "resource": attach_id, "status": f"ERROR: {str(e)}"}