Skip to content

Commit

Permalink
add more progress bars
Browse files Browse the repository at this point in the history
  • Loading branch information
badra001 committed Jan 15, 2026
1 parent 11296d1 commit 8522a09
Showing 1 changed file with 132 additions and 129 deletions.
261 changes: 132 additions & 129 deletions local-app/python-tools/cross-organization/tag-checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,164 +6,167 @@
import argparse
import sys
import time
import re
import os
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from botocore.exceptions import ClientError
from tqdm import tqdm

# Scanner release shown in the CLI description and the start-up banner.
__version__ = "1.1.0"

def get_args():
    """Parse the command-line options of the tag scanner.

    Returns:
        argparse.Namespace with the attributes ``role_name``, ``region``,
        ``profile``, ``tags_file``, ``max_workers``, ``account_regex``,
        ``accounts_from``, ``output`` and ``limit``.
    """
    parser = argparse.ArgumentParser(description=f"AWS Org Tag Scanner v{__version__}")
    parser.add_argument("--role-name", required=True, help="Role to assume in member accounts")
    parser.add_argument("--region", required=True, help="Primary region for API initialization")
    parser.add_argument("--profile", required=True, help="AWS CLI profile for Management Account")
    parser.add_argument("--tags-file", required=True, help="CSV file with Tag Key in the first column")
    parser.add_argument("--max-workers", type=int, default=8, help="Max concurrent account scans (default: 8)")
    parser.add_argument("--account-regex", help="Regex to filter accounts by alias/name")
    parser.add_argument("--accounts-from", help="File containing specific Account IDs to process")
    parser.add_argument("--output", default="tag_checker_findings", help="Prefix for output files")
    # Registered exactly once: argparse rejects a second add_argument() call
    # that reuses an existing option string.
    parser.add_argument("--limit", type=int, default=0,
                        help="Hard limit on total accounts to scan (0 for no limit)")
    return parser.parse_args()

def get_session(management_session, account_id, role_name, partition):
    """Assume ``role_name`` in a member account and return a session for it.

    Args:
        management_session: Session whose STS client performs the
            AssumeRole call.
        account_id: ID of the account that owns the role.
        role_name: Name of the role to assume.
        partition: ARN partition used to build the role ARN.

    Returns:
        A ``boto3.Session`` built from the temporary credentials, or ``None``
        when the role cannot be assumed.
    """
    sts = management_session.client('sts')
    role_arn = f"arn:{partition}:iam::{account_id}:role/{role_name}"
    try:
        response = sts.assume_role(RoleArn=role_arn, RoleSessionName="TagDiscoveryScanner")
    except Exception:
        # Best effort: callers treat None as "skip this account". Catching
        # Exception (not a bare except) lets Ctrl-C and SystemExit propagate.
        return None
    creds = response['Credentials']
    return boto3.Session(
        aws_access_key_id=creds['AccessKeyId'],
        aws_secret_access_key=creds['SecretAccessKey'],
        aws_session_token=creds['SessionToken'],
    )


# Name used before the rename; kept so call sites that still use it keep working.
get_session_for_account = get_session

def get_alias_fixed(session):
    """Return the first IAM account alias visible to ``session``.

    Args:
        session: Session for the member account; its IAM client is queried
            with ``list_account_aliases``.

    Returns:
        The first alias, or ``"No Alias"`` when the account has none or the
        lookup fails.
    """
    try:
        aliases = session.client('iam').list_account_aliases()['AccountAliases']
    except Exception:
        # Best effort: a failed lookup is reported the same way as a missing
        # alias, but Ctrl-C / SystemExit are no longer swallowed.
        return "No Alias"
    # Explicit emptiness check instead of relying on IndexError.
    return aliases[0] if aliases else "No Alias"

def scan_account(account, management_session, role_name, partition, tag_keys, region_name, position):
    """Scan one member account for resources carrying any of ``tag_keys``.

    Args:
        account: Account record; only its ``'Id'`` entry is read here.
        management_session: Session used to assume the role in the account.
        role_name: Name of the role to assume.
        partition: ARN partition used to build the role ARN.
        tag_keys: Tag keys to search for.
        region_name: Region used to list the account's regions; also the
            only region scanned when that listing fails.
        position: tqdm row ("lane") on which this account's bar is drawn.

    Returns:
        Tuple ``(findings, account_id, alias, hit_count)``. When the role
        cannot be assumed the tuple is ``([], account_id, "AUTH_FAIL", 0)``.
    """
    acc_id = account['Id']
    m_session = get_session(management_session, acc_id, role_name, partition)
    if not m_session:
        return [], acc_id, "AUTH_FAIL", 0

    alias = get_alias_fixed(m_session)
    findings = []

    try:
        ec2 = m_session.client('ec2', region_name=region_name)
        regions = [r['RegionName'] for r in ec2.describe_regions()['Regions']]
    except Exception:
        # Best effort: fall back to the single configured region.
        regions = [region_name]

    # One tagging client per region, built once instead of once per
    # (tag key, region) pair.
    clients = [(r, m_session.client('resourcegroupstaggingapi', region_name=r)) for r in regions]

    # Inner progress bar (the "lane"): one tick per tag key.
    pbar = tqdm(total=len(tag_keys), desc=f"Lane {position}: {acc_id}", position=position, leave=False)
    try:
        for key in tag_keys:
            for r, client in clients:
                try:
                    paginator = client.get_paginator('get_resources')
                    for page in paginator.paginate(TagFilters=[{'Key': key}]):
                        for mapping in page.get('ResourceTagMappingList', []):
                            val = next((t['Value'] for t in mapping['Tags'] if t['Key'] == key), "N/A")
                            findings.append({
                                "tag_name": key, "tag_value": val, "account_id": acc_id,
                                "account_alias": alias, "region": r, "arn": mapping['ResourceARN']
                            })
                except ClientError as e:
                    # A failed (key, region) query is skipped, not retried;
                    # the pause only eases pressure on the following call.
                    if "Throttling" in str(e):
                        time.sleep(1)
            pbar.update(1)
    finally:
        # Release the lane even if a query raised something unexpected.
        pbar.close()

    return findings, acc_id, alias, len(findings)


def main():
    """Scan the organization's active accounts for tagged resources.

    Reads the tag keys from ``--tags-file``, selects the accounts to scan,
    runs ``scan_account`` for each of them on a thread pool and writes the
    findings to ``<output>_<timestamp>.json`` and, when there are any,
    ``<output>_<timestamp>.csv``.
    """
    args = get_args()
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    start_time = time.time()

    # A non-positive worker count would break both the pool and the lane math.
    workers = max(1, args.max_workers)
    # Compiled up front so an invalid pattern fails before any scanning.
    name_filter = re.compile(args.account_regex) if args.account_regex else None

    # 1. Initialization
    session = boto3.Session(profile_name=args.profile)
    org = session.client('organizations')
    partition = session.client('sts').get_caller_identity()['Arn'].split(':')[1]

    # 2. Pre-scan Summary
    print(f"\n{'='*50}\nAWS TAG CHECKER v{__version__}\n{'='*50}")
    print(f"Profile: {args.profile} | Region: {args.region} | Role: {args.role_name}")

    # 3. Load Tags and Account Filters
    with open(args.tags_file, mode='r', encoding='utf-8-sig', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        tag_keys = [row[0].strip().replace('"', '') for row in reader if row]
    # Blank cells are dropped rather than sent to the API as empty tag keys.
    tag_keys = [key for key in tag_keys if key]

    allowed_ids = set()
    if args.accounts_from:
        with open(args.accounts_from, 'r') as f:
            allowed_ids = {line.strip() for line in f if line.strip()}

    # 4. Fetch and Filter Accounts
    all_raw_accounts = []
    for page in org.get_paginator('list_accounts').paginate():
        all_raw_accounts.extend(page['Accounts'])

    # The regex filter is applied after the scan, once the alias is known.
    to_process = [
        acc for acc in all_raw_accounts
        if acc['Status'] == 'ACTIVE' and (not allowed_ids or acc['Id'] in allowed_ids)
    ]

    if args.limit > 0:
        to_process = to_process[:args.limit]
    print(f"Accounts Found: {len(all_raw_accounts)} | Targeted: {len(to_process)}")
    print(f"{'='*50}\n")

    # 5. Multi-threaded Execution
    all_findings = []
    summary_data = []

    # Overall progress bar on row 0; account lanes use rows 1..workers.
    overall_pbar = tqdm(total=len(to_process), desc="Total Progress", position=0)

    # NOTE(review): every worker creates its STS client from the shared
    # management session; confirm against the boto3 threading guidance.
    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {
            executor.submit(scan_account, acc, session, args.role_name, partition,
                            tag_keys, args.region, (i % workers) + 1): acc
            for i, acc in enumerate(to_process)
        }

        for future in as_completed(futures):
            acc = futures[future]
            try:
                res, acc_id, alias, count = future.result()
            except Exception as exc:
                # One failing account must not abort the remaining scans.
                overall_pbar.write(f"[!] Failed {acc['Id']}: {exc}")
                overall_pbar.update(1)
                continue

            # --account-regex filters by "alias/name": keep the account when
            # either its IAM alias or its organization name matches.
            if name_filter and not (name_filter.search(alias) or name_filter.search(acc.get('Name', ''))):
                overall_pbar.write(f"[-] Skipped {acc_id} ({alias}): Regex mismatch")
            else:
                all_findings.extend(res)
                summary_data.append({"account_id": acc_id, "alias": alias, "hits": count})

            overall_pbar.update(1)

    overall_pbar.close()

    # 6. Output Files
    csv_file = f"{args.output}_{ts}.csv"
    json_file = f"{args.output}_{ts}.json"

    with open(json_file, 'w') as f:
        json.dump(all_findings, f, indent=4)
    if all_findings:
        with open(csv_file, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=all_findings[0].keys())
            writer.writeheader()
            writer.writerows(all_findings)

    print(f"\n[+] Done! Scanned {len(summary_data)} accounts in {round(time.time()-start_time, 2)}s")
    print(f"[+] Findings: {csv_file}")


if __name__ == "__main__":
    main()

0 comments on commit 8522a09

Please sign in to comment.