"""AWS Organization tag scanner.

Walks every ACTIVE account of an AWS Organization, assumes a role in each
member account and uses the Resource Groups Tagging API to list resources
that carry any of the tag keys read from a CSV file.  Accounts are scanned
concurrently; findings are written to timestamped JSON and CSV files.
"""

# NOTE(review): the first lines of the original file are outside the visible
# diff hunks; boto3, csv and json are reconstructed from their use below.
import argparse
import csv
import json
import os
import queue
import re
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime

import boto3
from botocore.exceptions import BotoCoreError, ClientError
from tqdm import tqdm

__version__ = "1.1.0"

# In-band status returned in the "alias" slot when the role cannot be assumed.
AUTH_FAIL = "AUTH_FAIL"

# Bounded retry for throttled tagging-API calls.
_MAX_THROTTLE_ATTEMPTS = 4
_MAX_BACKOFF_SECONDS = 8

# boto3 Session objects are not thread-safe; worker threads share the
# management session, so client creation on it is serialized.
_MANAGEMENT_SESSION_LOCK = threading.Lock()


def get_args():
    """Parse and validate the command-line arguments.

    Exits with an argparse usage error when --max-workers is below 1 or
    --account-regex is not a valid regular expression.
    """
    parser = argparse.ArgumentParser(description=f"AWS Org Tag Scanner v{__version__}")
    # NOTE(review): the --role-name definition lies outside the visible diff
    # hunks; main() reads args.role_name, so it is reconstructed here --
    # confirm the required flag and help text against the full file.
    parser.add_argument("--role-name", required=True, help="IAM role to assume in each member account")
    parser.add_argument("--region", required=True, help="Primary region for API initialization")
    parser.add_argument("--profile", required=True, help="AWS CLI profile for Management Account")
    parser.add_argument("--tags-file", required=True, help="CSV file with Tag Key in the first column")
    parser.add_argument("--max-workers", type=int, default=8, help="Max concurrent account scans (default: 8)")
    parser.add_argument("--account-regex", help="Regex to filter accounts by alias/name")
    parser.add_argument("--accounts-from", help="File containing specific Account IDs to process")
    parser.add_argument("--output", default="tag_checker_findings", help="Prefix for output files")
    parser.add_argument("--limit", type=int, default=0, help="Hard limit on total accounts to scan")
    args = parser.parse_args()

    # A worker count below 1 would break both the thread pool and the
    # progress-bar lane assignment.
    if args.max_workers < 1:
        parser.error("--max-workers must be at least 1")
    if args.account_regex:
        try:
            re.compile(args.account_regex)
        except re.error as e:
            parser.error(f"--account-regex is not a valid regular expression: {e}")
    return args


def get_session(management_session, account_id, role_name, partition):
    """Assume ``role_name`` in ``account_id`` and return a boto3 Session.

    Returns None when the AssumeRole call is rejected (ClientError).
    """
    role_arn = f"arn:{partition}:iam::{account_id}:role/{role_name}"
    # The management session is shared by all worker threads.
    with _MANAGEMENT_SESSION_LOCK:
        sts = management_session.client('sts')
    try:
        response = sts.assume_role(RoleArn=role_arn, RoleSessionName="TagDiscoveryScanner")
    except ClientError:
        return None
    creds = response['Credentials']
    return boto3.Session(
        aws_access_key_id=creds['AccessKeyId'],
        aws_secret_access_key=creds['SecretAccessKey'],
        aws_session_token=creds['SessionToken'],
    )


def get_alias_fixed(session):
    """Return the first IAM account alias of the member account.

    Returns "No Alias" when the account has no alias or the IAM call fails.
    """
    try:
        aliases = session.client('iam').list_account_aliases()['AccountAliases']
    except ClientError:
        return "No Alias"
    return aliases[0] if aliases else "No Alias"


def _matches_account(account_regex, name, alias):
    """Return True when no regex is set or it matches the name or the alias."""
    if not account_regex:
        return True
    return bool(re.search(account_regex, name) or re.search(account_regex, alias))


def _is_throttle(error):
    """Return True when a ClientError looks like an API throttling response."""
    code = error.response.get('Error', {}).get('Code', '')
    return "Throttl" in code or "Throttling" in str(error)


def _fetch_tag_mappings(client, key):
    """Fetch all resource mappings carrying tag ``key`` from one region.

    Throttled calls are retried from the first page with exponential backoff.
    Returns ``(mappings, error)``; ``error`` is None on success, otherwise the
    last ClientError, with ``mappings`` holding whatever pages were read
    before the failure.
    """
    paginator = client.get_paginator('get_resources')
    mappings = []
    last_error = None
    for attempt in range(_MAX_THROTTLE_ATTEMPTS):
        # Restart from scratch so a retried pagination cannot duplicate rows.
        mappings = []
        try:
            for page in paginator.paginate(TagFilters=[{'Key': key}]):
                mappings.extend(page.get('ResourceTagMappingList', []))
        except ClientError as e:
            last_error = e
            if not _is_throttle(e):
                break
            if attempt + 1 < _MAX_THROTTLE_ATTEMPTS:
                time.sleep(min(2 ** attempt, _MAX_BACKOFF_SECONDS))
        else:
            return mappings, None
    return mappings, last_error


def scan_account(account, management_session, role_name, partition, tag_keys,
                 region_name, position, account_regex=None):
    """Scan one member account for resources carrying any of ``tag_keys``.

    Args:
        account: Organizations account dict; 'Id' and 'Name' are read.
        management_session: boto3 Session of the management account.
        role_name: Role to assume in the member account.
        partition: AWS partition used to build the role ARN.
        tag_keys: Tag keys to search for.
        region_name: Region used for region discovery, and the only region
            scanned when discovery fails.
        position: tqdm row ("lane") for this account's progress bar.
        account_regex: Optional pattern; when it matches neither the account
            name nor its alias the account is returned empty without scanning.

    Returns:
        ``(findings, account_id, alias, hit_count)``.  ``alias`` is
        ``AUTH_FAIL`` when the role could not be assumed.
    """
    acc_id, acc_name = account['Id'], account['Name']
    m_session = get_session(management_session, acc_id, role_name, partition)
    if not m_session:
        return [], acc_id, AUTH_FAIL, 0

    alias = get_alias_fixed(m_session)
    # Decide on the filter before the expensive per-region scan.
    if not _matches_account(account_regex, acc_name, alias):
        return [], acc_id, alias, 0

    try:
        ec2 = m_session.client('ec2', region_name=region_name)
        regions = [r['RegionName'] for r in ec2.describe_regions()['Regions']]
    except ClientError:
        regions = [region_name]

    # One tagging client per region, reused for every tag key.
    clients = {
        r: m_session.client('resourcegroupstaggingapi', region_name=r)
        for r in regions
    }

    findings = []
    failed_queries = 0
    last_error = None

    # Inner progress bar (The "Lane")
    pbar = tqdm(total=len(tag_keys), desc=f"Lane {position}: {acc_id}", position=position, leave=False)
    try:
        for key in tag_keys:
            for region, client in clients.items():
                mappings, error = _fetch_tag_mappings(client, key)
                if error is not None:
                    failed_queries += 1
                    last_error = error
                for mapping in mappings:
                    val = next((t['Value'] for t in mapping['Tags'] if t['Key'] == key), "N/A")
                    findings.append({
                        "tag_name": key, "tag_value": val, "account_id": acc_id,
                        "account_alias": alias, "region": region, "arn": mapping['ResourceARN']
                    })
            pbar.update(1)
    finally:
        pbar.close()

    if failed_queries:
        code = last_error.response.get('Error', {}).get('Code', 'Unknown')
        tqdm.write(f"[!] {acc_id}: {failed_queries} tag queries failed (last error: {code})")

    return findings, acc_id, alias, len(findings)


def main():
    """Run the organization-wide tag scan and write the result files."""
    args = get_args()
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    start_time = time.time()

    # 1. Initialization
    try:
        session = boto3.Session(profile_name=args.profile)
        org = session.client('organizations')
        partition = session.client('sts').get_caller_identity()['Arn'].split(':')[1]
    except (BotoCoreError, ClientError) as e:
        print(f"[!] Initialization Error: {e}")
        sys.exit(1)

    # 2. Pre-scan Summary
    print(f"\n{'='*50}\nAWS TAG CHECKER v{__version__}\n{'='*50}")
    print(f"Profile: {args.profile} | Region: {args.region} | Role: {args.role_name}")

    # 3. Load Tags and Account Filters
    try:
        with open(args.tags_file, mode='r', encoding='utf-8-sig') as f:
            rows = list(csv.reader(f))[1:]  # first row is the header
        tag_keys = [row[0].strip().replace('"', '') for row in rows if row]
        tag_keys = [key for key in tag_keys if key]  # an empty key is not a valid filter

        allowed_ids = set()
        if args.accounts_from:
            with open(args.accounts_from, 'r', encoding='utf-8-sig') as f:
                allowed_ids = {line.strip() for line in f if line.strip()}
    except OSError as e:
        print(f"[!] File Error: {e}")
        sys.exit(1)

    if not tag_keys:
        print(f"[!] File Error: no tag keys found in {args.tags_file}")
        sys.exit(1)

    # 4. Fetch and Filter Accounts
    all_raw_accounts = []
    try:
        paginator = org.get_paginator('list_accounts')
        for page in paginator.paginate():
            all_raw_accounts.extend(page['Accounts'])
    except (BotoCoreError, ClientError) as e:
        print(f"[!] Initialization Error: {e}")
        sys.exit(1)

    # The regex filter needs the IAM alias, which is only known after the
    # role is assumed, so it is applied inside scan_account.
    to_process = [
        acc for acc in all_raw_accounts
        if acc['Status'] == 'ACTIVE' and (not allowed_ids or acc['Id'] in allowed_ids)
    ]

    if args.limit > 0:
        to_process = to_process[:args.limit]
    print(f"Accounts Found: {len(all_raw_accounts)} | Targeted: {len(to_process)}")
    print(f"{'='*50}\n")

    # 5. Multi-threaded Execution
    all_findings = []
    summary_data = []

    # Pool of free progress-bar rows (1..max_workers).  A worker borrows a
    # row for the duration of one account so concurrent scans never share one.
    lanes = queue.SimpleQueue()
    for lane in range(1, args.max_workers + 1):
        lanes.put(lane)

    def _scan_in_lane(acc):
        lane = lanes.get()
        try:
            return scan_account(acc, session, args.role_name, partition,
                                tag_keys, args.region, lane, args.account_regex)
        finally:
            lanes.put(lane)

    # Overall Progress Bar
    with tqdm(total=len(to_process), desc="Total Progress", position=0) as overall_pbar:
        with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
            futures = {executor.submit(_scan_in_lane, acc): acc for acc in to_process}

            for future in as_completed(futures):
                acc = futures[future]
                try:
                    res, acc_id, alias, count = future.result()
                except Exception as e:  # boundary: one bad account must not abort the run
                    overall_pbar.write(f"[!] Skipping {acc['Id']}: {e}")
                    overall_pbar.update(1)
                    continue

                if alias == AUTH_FAIL:
                    overall_pbar.write(f"[!] Skipping {acc_id}: Cannot assume role")
                elif not _matches_account(args.account_regex, acc['Name'], alias):
                    overall_pbar.write(f"[-] Skipped {acc_id} ({alias}): Regex mismatch")
                else:
                    all_findings.extend(res)
                    summary_data.append({"account_id": acc_id, "alias": alias, "hits": count})

                overall_pbar.update(1)

    # 6. Output Files
    csv_file = f"{args.output}_{ts}.csv"
    json_file = f"{args.output}_{ts}.json"

    with open(json_file, 'w') as f:
        json.dump(all_findings, f, indent=4)
    if all_findings:
        with open(csv_file, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=all_findings[0].keys())
            writer.writeheader()
            writer.writerows(all_findings)

    print(f"\n[+] Done! Scanned {len(summary_data)} accounts in {round(time.time()-start_time, 2)}s")
    if all_findings:
        print(f"[+] Findings: {csv_file}")
    else:
        # No CSV is written without findings; point at the file that exists.
        print(f"[+] No findings; empty result written to {json_file}")


if __name__ == "__main__":
    main()