Skip to content

Commit

Permalink
add threading
Browse files Browse the repository at this point in the history
  • Loading branch information
badra001 committed Jan 16, 2026
1 parent 234bfbc commit dba0a12
Showing 1 changed file with 33 additions and 37 deletions.
70 changes: 33 additions & 37 deletions local-app/python-tools/cross-organization/tag-checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@
import re
import os
import resource
import threading
from datetime import datetime, timezone, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from botocore.exceptions import ClientError
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session as get_botocore_session
from tqdm import tqdm

__version__ = "1.1.16"
__version__ = "1.1.17"

# Counter for global sequence tracking
ACCOUNT_COUNTER = 0
COUNTER_LOCK = threading.Lock()

def get_args():
parser = argparse.ArgumentParser(description=f"AWS Org Tag Scanner v{__version__}")
Expand All @@ -31,7 +36,7 @@ def get_args():
parser.add_argument("--output", default="tag_checker", help="Prefix for output files")
parser.add_argument("--limit", type=int, default=0, help="Limit total accounts processed")
parser.add_argument("--verbose", action="store_true", help="Enable detailed logging")
parser.add_argument("--list-accounts", action="store_true", help="List all unique Account IDs and exit")
parser.add_argument("--list-accounts", action="store_true", help="List Account IDs and exit")
return parser.parse_args()

def create_refreshable_session(profile_name, region_name):
Expand All @@ -46,9 +51,7 @@ def refresh_credentials():
"expiry_time": creds._expiry_time.isoformat() if creds._expiry_time else (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat()
}
session_creds = RefreshableCredentials.create_from_metadata(
metadata=refresh_credentials(),
refresh_using=refresh_credentials,
method="sts-assume-role"
metadata=refresh_credentials(), refresh_using=refresh_credentials, method="sts-assume-role"
)
bc_session._credentials = session_creds
bc_session.set_config_variable("region", region_name)
Expand All @@ -67,34 +70,35 @@ def get_member_session(management_session, account_id, role_name, partition, reg
return None

def scan_account(account, management_session, role_name, partition, tag_keys, active_tag_keys, region_name, lane_id, account_regex, region_regex_str, verbose, bar_width):
global ACCOUNT_COUNTER
with COUNTER_LOCK:
ACCOUNT_COUNTER += 1
current_index = ACCOUNT_COUNTER

acc_id = account['Id']
m_session = get_member_session(management_session, acc_id, role_name, partition, region_name, verbose)
if not m_session:
return [], acc_id, "N/A", {}, "Auth Fail"

if not m_session: return [], acc_id, "N/A", {}, "Auth Fail"

try:
alias_resp = m_session.client('iam', region_name=region_name).list_account_aliases()
alias = alias_resp.get('AccountAliases', ["N/A"])[0]
except Exception:
alias = "N/A"

except Exception: alias = "N/A"

if account_regex and not re.search(account_regex, alias, re.IGNORECASE):
return [], acc_id, alias, {}, f"Regex Skip ({alias})"

try:
ec2 = m_session.client('ec2', region_name=region_name)
all_regions = [r['RegionName'] for r in ec2.describe_regions()['Regions']]
if region_regex_str:
active_regions = [r for r in all_regions if re.search(region_regex_str, r, re.IGNORECASE)]
else:
active_regions = all_regions
except Exception:
active_regions = [region_name]
active_regions = [r for r in all_regions if re.search(region_regex_str, r, re.IGNORECASE)] if region_regex_str else all_regions
except: active_regions = [region_name]

acc_start, findings, global_resources, global_tags_found, regional_metrics = time.time(), [], set(), set(), []
label = f"{acc_id} {alias}".ljust(bar_width)
pbar = tqdm(total=len(tag_keys), desc=f"Lane {lane_id:02d} | {label}", position=lane_id, leave=False, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}')

# FORMAT: {lane_id} | {index} | {acc_id} {alias}
label = f"{lane_id:02d} | {current_index:03d} | {acc_id} {alias}".ljust(bar_width)
pbar = tqdm(total=len(tag_keys), desc=label, position=lane_id, leave=False, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}')

for key in tag_keys:
for r in active_regions:
r_start, r_hits, r_res_found, r_tags_found = time.perf_counter(), 0, set(), set()
Expand All @@ -112,9 +116,8 @@ def scan_account(account, management_session, role_name, partition, tag_keys, ac

r_elapsed = round(time.perf_counter() - r_start, 4)
r_entry = next((m for m in regional_metrics if m['region'] == r), None)
r_active_found = sorted(list(r_tags_found.intersection(active_tag_keys)))
if not r_entry:
regional_metrics.append({"region": r, "hits": r_hits, "unique_resources": len(r_res_found), "tags_found_count": len(r_tags_found), "tags_found_list": sorted(list(r_tags_found)), "tags_found_list_active": r_active_found, "tags_not_found_count": len(tag_keys) - len(r_tags_found), "elapsed_sec": r_elapsed})
regional_metrics.append({"region": r, "hits": r_hits, "unique_resources": len(r_res_found), "tags_found_count": len(r_tags_found), "tags_found_list": sorted(list(r_tags_found)), "tags_found_list_active": sorted(list(r_tags_found.intersection(active_tag_keys))), "tags_not_found_count": len(tag_keys) - len(r_tags_found), "elapsed_sec": r_elapsed})
else:
r_entry['hits'] += r_hits; current_tags = set(r_entry['tags_found_list']) | r_tags_found
r_entry['tags_found_list'] = sorted(list(current_tags)); r_entry['tags_found_list_active'] = sorted(list(current_tags.intersection(active_tag_keys)))
Expand Down Expand Up @@ -145,15 +148,11 @@ def main():
for aid in sorted(unique_accounts.keys()): print(aid)
sys.exit(0)

if not args.role_name or not args.tags_file:
print("[!] Error: --role-name and --tags-file are required unless using --list-accounts.")
sys.exit(1)

tag_keys, active_tag_keys = [], set()
with open(args.tags_file, mode='r', encoding='utf-8-sig') as f:
reader = csv.DictReader(f, skipinitialspace=True)
for row in reader:
key = row.get('TagKey', '').strip().replace('"', '')
key = row.get('TagKey', '').strip()
if key:
tag_keys.append(key)
if row.get('Status', '').strip().lower() == 'active': active_tag_keys.add(key)
Expand All @@ -165,17 +164,17 @@ def main():
to_process = [v for k, v in unique_accounts.items() if not target_ids or k in target_ids]
if args.limit > 0: to_process = to_process[:args.limit]

max_label_len = max([12 + 1 + len(a['Name']) for a in to_process]) + 1 if to_process else 40
# UI Width Calculation: "01 | 001 | 123456789012 MyAlias"
max_label_len = max([3 + 3 + 3 + 12 + 1 + len(a['Name']) for a in to_process]) + 2 if to_process else 50

print(f"\n{'='*85}\nAWS TAG CHECKER v{__version__}\n{'='*85}")
print(f"Profile: {args.profile} | Region: {args.region} | Role: {args.role_name}")
if args.region_regex: print(f"Region Regex: {args.region_regex}")
print(f"Tags Read: {len(tag_keys)} ({len(active_tag_keys)} active)")
print(f"Accounts Targeted: {len(to_process)}")
print(f"{'='*85}\n")
print(f"Accounts Targeted: {len(to_process)} (Unique Total: {len(unique_accounts)})")
print(f"Thread Count: {args.max_workers}\n{'='*85}\n")

all_findings, account_results = [], []
overall_pbar = tqdm(total=len(to_process), desc="Total Org Progress", position=0)
overall_pbar = tqdm(total=len(to_process), desc="Overall Progress", position=0)

with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
try:
Expand All @@ -184,24 +183,21 @@ def main():
res, acc_id, alias, m, status = future.result()
if status == "Success":
all_findings.extend(res); account_results.append({"account_id": acc_id, "alias": alias, "global_metrics": m["global"], "regional_metrics": m["regions"]})
else: overall_pbar.write(f"[-] {acc_id}: {status}")
overall_pbar.update(1)
except KeyboardInterrupt: executor.shutdown(wait=False, cancel_futures=True); sys.exit(130)

overall_pbar.close(); print("\n" * (args.max_workers + 1))

mem_mb = round(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024, 2)
total_unique_res = len(set(f['arn'] for f in all_findings))
all_found_keys = set(f['tag_name'] for f in all_findings)

output_summary = {"summary": {"version": __version__, "command_line": cmd_line, "aws_accounts_scanned": len(account_results), "tags_read_count": len(tag_keys), "execution_start": start_iso, "execution_end": datetime.now().isoformat(), "elapsed_sec_total": round(time.time() - start_ts, 2), "max_memory_mb": mem_mb, "total_hits": sum(a['global_metrics']['hits'] for a in account_results), "total_unique_resources": total_unique_res, "total_tags_found_count": len(all_found_keys), "total_tags_not_found_count": len(tag_keys) - len(all_found_keys)}, "accounts": account_results}
output_summary = {"summary": {"version": __version__, "command_line": cmd_line, "aws_accounts_scanned": len(account_results), "tags_read_count": len(tag_keys), "execution_start": start_iso, "execution_end": datetime.now().isoformat(), "elapsed_sec_total": round(time.time() - start_ts, 2), "max_memory_mb": mem_mb, "total_hits": sum(a['global_metrics']['hits'] for a in account_results), "total_unique_resources": total_unique_res, "total_tags_found_count": len(all_found_keys)}, "accounts": account_results}

sum_f, fin_f = f"{args.output}_summary_{ts}.json", f"{args.output}_findings_{ts}.csv"
with open(sum_f, 'w') as f: json.dump(output_summary, f, indent=4)
if all_findings:
with open(fin_f, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=all_findings[0].keys()); writer.writeheader(); writer.writerows(all_findings)

print(f"[+] Summary: {sum_f}\n[+] Findings: {fin_f}")
except KeyboardInterrupt: sys.exit(130)

Expand Down

0 comments on commit dba0a12

Please sign in to comment.