diff --git a/local-app/python-tools/cross-organization/tag-checker.py b/local-app/python-tools/cross-organization/tag-checker.py index 620888da..756a67ea 100755 --- a/local-app/python-tools/cross-organization/tag-checker.py +++ b/local-app/python-tools/cross-organization/tag-checker.py @@ -9,38 +9,33 @@ import re import os import resource -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta from concurrent.futures import ThreadPoolExecutor, as_completed from botocore.exceptions import ClientError from botocore.credentials import RefreshableCredentials from botocore.session import get_session as get_botocore_session from tqdm import tqdm -__version__ = "1.1.14" +__version__ = "1.1.15" def get_args(): parser = argparse.ArgumentParser(description=f"AWS Org Tag Scanner v{__version__}") - parser.add_argument("--role-name", required=True, help="Role to assume in member accounts") + parser.add_argument("--role-name", required=False, help="Role to assume in member accounts") parser.add_argument("--region", required=True, help="Management account region (e.g., us-gov-east-1)") parser.add_argument("--profile", required=True, help="AWS CLI profile for Management Account") - parser.add_argument("--tags-file", required=True, help="CSV file with TagKey, Status, Type, etc.") + parser.add_argument("--tags-file", required=False, help="CSV file with TagKey, Type, Status, etc.") parser.add_argument("--max-workers", type=int, default=8, help="Max concurrent account scans") parser.add_argument("--account-regex", help="Regex to filter accounts by alias") parser.add_argument("--accounts-from", help="File of Account IDs to process") parser.add_argument("--output", default="tag_checker", help="Prefix for output files") parser.add_argument("--limit", type=int, default=0, help="Limit total accounts processed") parser.add_argument("--verbose", action="store_true", help="Enable detailed logging") + parser.add_argument("--list-accounts", action="store_true", help="List all unique Account IDs and exit") return parser.parse_args() def create_refreshable_session(profile_name, region_name): - """ - Creates a Boto3 session that automatically refreshes credentials - by re-reading the profile (SSO or otherwise) when they expire. - """ bc_session = get_botocore_session() - def refresh_credentials(): - # This function is called by botocore when creds expire temp_session = boto3.Session(profile_name=profile_name, region_name=region_name) creds = temp_session.get_credentials() return { @@ -49,35 +44,23 @@ def refresh_credentials(): "token": creds.token, "expiry_time": creds._expiry_time.isoformat() if creds._expiry_time else (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() } - session_creds = RefreshableCredentials.create_from_metadata( metadata=refresh_credentials(), refresh_using=refresh_credentials, method="sts-assume-role" ) - bc_session._credentials = session_creds bc_session.set_config_variable("region", region_name) return boto3.Session(botocore_session=bc_session) def get_member_session(management_session, account_id, role_name, partition, region_name, verbose): - """Assumes role in member account with maximum allowed duration (3600s).""" sts = management_session.client('sts', region_name=region_name) role_arn = f"arn:{partition}:iam::{account_id}:role/{role_name}" try: - # Request 1 hour (3600s) to minimize refresh cycles - response = sts.assume_role( - RoleArn=role_arn, - RoleSessionName="TagDiscoveryScanner", - DurationSeconds=3600 - ) + response = sts.assume_role(RoleArn=role_arn, RoleSessionName="TagDiscoveryScanner", DurationSeconds=3600) c = response['Credentials'] - return boto3.Session( - aws_access_key_id=c['AccessKeyId'], - aws_secret_access_key=c['SecretAccessKey'], - aws_session_token=c['SessionToken'], - region_name=region_name - ) + return boto3.Session(aws_access_key_id=c['AccessKeyId'], aws_secret_access_key=c['SecretAccessKey'], + aws_session_token=c['SessionToken'], region_name=region_name) except Exception as e: if verbose: tqdm.write(f"[!] Auth Error for {account_id}: {str(e)}") return None @@ -85,36 +68,26 @@ def get_member_session(management_session, account_id, role_name, partition, reg def scan_account(account, management_session, role_name, partition, tag_keys, active_tag_keys, region_name, lane_id, account_regex, verbose, bar_width): acc_id = account['Id'] m_session = get_member_session(management_session, acc_id, role_name, partition, region_name, verbose) - if not m_session: return [], acc_id, "N/A", {}, "Auth Fail" - try: alias_resp = m_session.client('iam', region_name=region_name).list_account_aliases() alias = alias_resp.get('AccountAliases', ["N/A"])[0] except Exception: alias = "N/A" - if account_regex and not re.search(account_regex, alias, re.IGNORECASE): return [], acc_id, alias, {}, f"Regex Skip ({alias})" - try: ec2 = m_session.client('ec2', region_name=region_name) active_regions = [r['RegionName'] for r in ec2.describe_regions()['Regions']] except Exception: active_regions = [region_name] - - acc_start = time.time() - findings, global_resources, global_tags_found, regional_metrics = [], set(), set(), [] - + acc_start, findings, global_resources, global_tags_found, regional_metrics = time.time(), [], set(), set(), [] label = f"{acc_id} {alias}".ljust(bar_width) - pbar = tqdm(total=len(tag_keys), desc=f"Lane {lane_id:02d} | {label}", - position=lane_id, leave=False, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}') - + pbar = tqdm(total=len(tag_keys), desc=f"Lane {lane_id:02d} | {label}", position=lane_id, leave=False, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}') for key in tag_keys: for r in active_regions: - r_start = time.perf_counter() - r_hits, r_res_found, r_tags_found = 0, set(), set() + r_start, r_hits, r_res_found, r_tags_found = time.perf_counter(), 0, set(), set() client = m_session.client('resourcegroupstaggingapi', region_name=r) try: paginator = client.get_paginator('get_resources') @@ -122,53 +95,22 @@ def scan_account(account, management_session, role_name, partition, tag_keys, ac for mapping in page.get('ResourceTagMappingList', []): arn = mapping['ResourceARN'] val = next((t['Value'] for t in mapping['Tags'] if t['Key'] == key), "N/A") - findings.append({ - "tag_name": key, "tag_value": val, "account_id": acc_id, - "account_alias": alias, "region": r, "arn": arn - }) - global_resources.add(arn) - global_tags_found.add(key) - r_res_found.add(arn) - r_tags_found.add(key) - r_hits += 1 + findings.append({"tag_name": key, "tag_value": val, "account_id": acc_id, "account_alias": alias, "region": r, "arn": arn}) + global_resources.add(arn); global_tags_found.add(key); r_res_found.add(arn); r_tags_found.add(key); r_hits += 1 except ClientError as e: if "Throttling" in str(e): time.sleep(1) - r_elapsed = round(time.perf_counter() - r_start, 4) r_entry = next((m for m in regional_metrics if m['region'] == r), None) r_active_found = sorted(list(r_tags_found.intersection(active_tag_keys))) - if not r_entry: - regional_metrics.append({ - "region": r, "hits": r_hits, "unique_resources": len(r_res_found), - "tags_found_count": len(r_tags_found), - "tags_found_list": sorted(list(r_tags_found)), - "tags_found_list_active": r_active_found, - "tags_not_found_count": len(tag_keys) - len(r_tags_found), - "elapsed_sec": r_elapsed - }) + regional_metrics.append({"region": r, "hits": r_hits, "unique_resources": len(r_res_found), "tags_found_count": len(r_tags_found), "tags_found_list": sorted(list(r_tags_found)), "tags_found_list_active": r_active_found, "tags_not_found_count": len(tag_keys) - len(r_tags_found), "elapsed_sec": r_elapsed}) else: - r_entry['hits'] += r_hits - current_tags = set(r_entry['tags_found_list']) | r_tags_found - r_entry['tags_found_list'] = sorted(list(current_tags)) - r_entry['tags_found_list_active'] = sorted(list(current_tags.intersection(active_tag_keys))) - r_entry['tags_found_count'] = len(current_tags) - r_entry['tags_not_found_count'] = len(tag_keys) - len(current_tags) - r_entry['elapsed_sec'] = round(r_entry['elapsed_sec'] + r_elapsed, 4) + r_entry['hits'] += r_hits; current_tags = set(r_entry['tags_found_list']) | r_tags_found + r_entry['tags_found_list'] = sorted(list(current_tags)); r_entry['tags_found_list_active'] = sorted(list(current_tags.intersection(active_tag_keys))) + r_entry['tags_found_count'] = len(current_tags); r_entry['tags_not_found_count'] = len(tag_keys) - len(current_tags); r_entry['elapsed_sec'] = round(r_entry['elapsed_sec'] + r_elapsed, 4) pbar.update(1) - pbar.close() - metrics = { - "global": { - "hits": len(findings), "unique_resources": len(global_resources), - "tags_found_count": len(global_tags_found), - "tags_found_list": sorted(list(global_tags_found)), - "tags_found_list_active": sorted(list(global_tags_found.intersection(active_tag_keys))), - "tags_not_found_count": len(tag_keys) - len(global_tags_found), - "elapsed_sec": round(time.time() - acc_start, 2) - }, - "regions": regional_metrics - } + metrics = {"global": {"hits": len(findings), "unique_resources": len(global_resources), "tags_found_count": len(global_tags_found), "tags_found_list": sorted(list(global_tags_found)), "tags_found_list_active": sorted(list(global_tags_found.intersection(active_tag_keys))), "tags_not_found_count": len(tag_keys) - len(global_tags_found), "elapsed_sec": round(time.time() - acc_start, 2)}, "regions": regional_metrics} return findings, acc_id, alias, metrics, "Success" def main(): @@ -176,13 +118,21 @@ def main(): cmd_line = " ".join(sys.argv) ts = datetime.now().strftime("%Y%m%d_%H%M%S") start_iso, start_ts = datetime.now().isoformat(), time.time() - try: - # Use the new refreshable session for the management connection session = create_refreshable_session(args.profile, args.region) org = session.client('organizations', region_name=args.region) partition = session.client('sts', region_name=args.region).get_caller_identity()['Arn'].split(':')[1] - + unique_accounts = {} + paginator = org.get_paginator('list_accounts') + for page in paginator.paginate(): + for a in page['Accounts']: + if a['Status'] == 'ACTIVE': unique_accounts[a['Id']] = a + if args.list_accounts: + for aid in sorted(unique_accounts.keys()): print(aid) + sys.exit(0) + if not args.role_name or not args.tags_file: + print("[!] Error: --role-name and --tags-file are required unless using --list-accounts.") + sys.exit(1) tag_keys, active_tag_keys = [], set() with open(args.tags_file, mode='r', encoding='utf-8-sig') as f: reader = csv.DictReader(f, skipinitialspace=True) @@ -190,82 +140,37 @@ def main(): key = row.get('TagKey', '').strip().replace('"', '') if key: tag_keys.append(key) - if row.get('Status', '').strip().lower() == 'active': - active_tag_keys.add(key) - + if row.get('Status', '').strip().lower() == 'active': active_tag_keys.add(key) target_ids = [] if args.accounts_from: - with open(args.accounts_from, 'r') as f: - target_ids = [l.strip() for l in f if l.strip()] - - unique_accounts = {} - paginator = org.get_paginator('list_accounts') - for page in paginator.paginate(): - for a in page['Accounts']: - if a['Status'] == 'ACTIVE': - unique_accounts[a['Id']] = a - + with open(args.accounts_from, 'r') as f: target_ids = [l.strip() for l in f if l.strip()] to_process = [v for k, v in unique_accounts.items() if not target_ids or k in target_ids] if args.limit > 0: to_process = to_process[:args.limit] - max_label_len = max([12 + 1 + len(a['Name']) for a in to_process]) + 1 if to_process else 40 - - print(f"\n{'='*85}\nAWS TAG CHECKER v{__version__}\n{'='*85}") - print(f"Profile: {args.profile} | Region: {args.region} | Role: {args.role_name}") - print(f"Tags Read: {len(tag_keys)} ({len(active_tag_keys)} active)") - print(f"Accounts Targeted: {len(to_process)}") - print(f"{'='*85}\n") - + print(f"\n{'='*85}\nAWS TAG CHECKER v{__version__}\n{'='*85}\nProfile: {args.profile} | Region: {args.region} | Role: {args.role_name}\nTags Read: {len(tag_keys)} ({len(active_tag_keys)} active)\nAccounts Targeted: {len(to_process)}\n{'='*85}\n") all_findings, account_results = [], [] overall_pbar = tqdm(total=len(to_process), desc="Total Org Progress", position=0) - with ThreadPoolExecutor(max_workers=args.max_workers) as executor: try: - futures = {executor.submit(scan_account, acc, session, args.role_name, partition, - tag_keys, active_tag_keys, args.region, (i % args.max_workers) + 1, - args.account_regex, args.verbose, max_label_len): acc for i, acc in enumerate(to_process)} + futures = {executor.submit(scan_account, acc, session, args.role_name, partition, tag_keys, active_tag_keys, args.region, (i % args.max_workers) + 1, args.account_regex, args.verbose, max_label_len): acc for i, acc in enumerate(to_process)} for future in as_completed(futures): res, acc_id, alias, m, status = future.result() if status == "Success": - all_findings.extend(res) - account_results.append({"account_id": acc_id, "alias": alias, "global_metrics": m["global"], "regional_metrics": m["regions"]}) - else: - overall_pbar.write(f"[-] {acc_id}: {status}") + all_findings.extend(res); account_results.append({"account_id": acc_id, "alias": alias, "global_metrics": m["global"], "regional_metrics": m["regions"]}) + else: overall_pbar.write(f"[-] {acc_id}: {status}") overall_pbar.update(1) - except KeyboardInterrupt: - executor.shutdown(wait=False, cancel_futures=True) - sys.exit(130) - - overall_pbar.close() - print("\n" * (args.max_workers + 1)) - + except KeyboardInterrupt: executor.shutdown(wait=False, cancel_futures=True); sys.exit(130) + overall_pbar.close(); print("\n" * (args.max_workers + 1)) mem_mb = round(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024, 2) - total_unique_res = len(set(f['arn'] for f in all_findings)) - all_found_keys = set(f['tag_name'] for f in all_findings) - - output_summary = { - "summary": { - "version": __version__, "command_line": cmd_line, "aws_accounts_scanned": len(account_results), - "tags_read_count": len(tag_keys), "execution_start": start_iso, "execution_end": datetime.now().isoformat(), - "elapsed_sec_total": round(time.time() - start_ts, 2), "max_memory_mb": mem_mb, - "total_hits": sum(a['global_metrics']['hits'] for a in account_results), - "total_unique_resources": total_unique_res, "total_tags_found_count": len(all_found_keys), - "total_tags_not_found_count": len(tag_keys) - len(all_found_keys) - }, - "accounts": account_results - } - + total_unique_res = len(set(f['arn'] for f in all_findings)); all_found_keys = set(f['tag_name'] for f in all_findings) + output_summary = {"summary": {"version": __version__, "command_line": cmd_line, "aws_accounts_scanned": len(account_results), "tags_read_count": len(tag_keys), "execution_start": start_iso, "execution_end": datetime.now().isoformat(), "elapsed_sec_total": round(time.time() - start_ts, 2), "max_memory_mb": mem_mb, "total_hits": sum(a['global_metrics']['hits'] for a in account_results), "total_unique_resources": total_unique_res, "total_tags_found_count": len(all_found_keys), "total_tags_not_found_count": len(tag_keys) - len(all_found_keys)}, "accounts": account_results} sum_f, fin_f = f"{args.output}_summary_{ts}.json", f"{args.output}_findings_{ts}.csv" with open(sum_f, 'w') as f: json.dump(output_summary, f, indent=4) if all_findings: with open(fin_f, 'w', newline='') as f: - writer = csv.DictWriter(f, fieldnames=all_findings[0].keys()) - writer.writeheader(); writer.writerows(all_findings) - + writer = csv.DictWriter(f, fieldnames=all_findings[0].keys()); writer.writeheader(); writer.writerows(all_findings) print(f"[+] Summary: {sum_f}\n[+] Findings: {fin_f}") - - except KeyboardInterrupt: - sys.exit(130) + except KeyboardInterrupt: sys.exit(130) if __name__ == "__main__": main()