diff --git a/local-app/python-tools/cross-organization/tag-checker.py b/local-app/python-tools/cross-organization/tag-checker.py index 3f2c97c5..64c677b9 100755 --- a/local-app/python-tools/cross-organization/tag-checker.py +++ b/local-app/python-tools/cross-organization/tag-checker.py @@ -8,8 +8,9 @@ import time from datetime import datetime from botocore.exceptions import ClientError +from tqdm import tqdm -__version__ = "1.0.4" +__version__ = "1.0.6" def get_args(): parser = argparse.ArgumentParser(description=f"AWS Org Tag Scanner v{__version__}") @@ -28,10 +29,11 @@ def get_session_for_account(management_session, account_id, role_name, partition RoleArn=role_arn, RoleSessionName="TagDiscoveryScanner" ) + creds = response['Credentials'] return boto3.Session( - aws_access_key_id=response['Credentials']['AccessKeyId'], - aws_secret_access_key=response['Credentials']['SecretAccessKey'], - aws_session_token=response['Credentials']['SessionToken'] + aws_access_key_id=creds['AccessKeyId'], + aws_secret_access_key=creds['SecretAccessKey'], + aws_session_token=creds['SessionToken'] ) except ClientError: return None @@ -48,114 +50,120 @@ def main(): args = get_args() start_time_overall = time.time() - session = boto3.Session(profile_name=args.profile) - org_client = session.client('organizations') - partition = session.client('sts').get_caller_identity()['Arn'].split(':')[1] + try: + session = boto3.Session(profile_name=args.profile) + org_client = session.client('organizations') + partition = session.client('sts').get_caller_identity()['Arn'].split(':')[1] + except Exception as e: + print(f"[!] Initialization Error: {e}") + sys.exit(1) # Load Tag Keys tag_keys = [] - with open(args.tags_file, mode='r', encoding='utf-8-sig') as f: - reader = csv.reader(f) - next(reader) - tag_keys = [row[0].strip() for row in reader if row] + try: + with open(args.tags_file, mode='r', encoding='utf-8-sig') as f: + reader = csv.reader(f) + next(reader) + tag_keys = [row[0].strip().replace('"', '') for row in reader if row] + except Exception as e: + print(f"[!] File Error: {e}") + sys.exit(1) + + # Get Account List + all_accounts = [] + paginator = org_client.get_paginator('list_accounts') + for page in paginator.paginate(): + for account in page['Accounts']: + if account['Status'] == 'ACTIVE': + all_accounts.append(account) + + if args.limit > 0: + all_accounts = all_accounts[:args.limit] - print(f"[*] Starting Scan v{__version__} (Testing {len(tag_keys)} tags)") + print(f"[*] Starting Scan v{__version__} | Tags: {len(tag_keys)} | Accounts: {len(all_accounts)}") all_results = [] account_metrics = [] - total_resources_found = 0 - accounts_processed = 0 - paginator = org_client.get_paginator('list_accounts') - for page in paginator.paginate(): - for account in page['Accounts']: - if account['Status'] != 'ACTIVE': continue - - # Apply Account Limit for Testing - if args.limit > 0 and accounts_processed >= args.limit: - break - - acc_start_time = time.time() - acc_id, acc_name = account['Id'], account['Name'] - print(f" --> Account {accounts_processed + 1}: {acc_id} ({acc_name})...", end="\r") - - m_session = get_session_for_account(session, acc_id, args.role_name, partition) - if not m_session: - print(f"\n [!] Skipped {acc_id}: Access Denied (Check Role)") - continue - - alias = get_account_alias(m_session) - resources_in_account = 0 - - try: - ec2 = m_session.client('ec2', region_name=args.region) - regions = [r['RegionName'] for r in ec2.describe_regions()['Regions']] - except ClientError: - regions = [args.region] + # Overall Progress Bar (Position 0) + outer_pbar = tqdm(total=len(all_accounts), desc="Overall Progress", position=0) + + for account in all_accounts: + acc_start_time = time.time() + acc_id, acc_name = account['Id'], account['Name'] + + # Update description to show current account + outer_pbar.set_description(f"Scanning {acc_id}") + + m_session = get_session_for_account(session, acc_id, args.role_name, partition) + if not m_session: + outer_pbar.write(f"[!] Skipping {acc_id}: Cannot assume role") + outer_pbar.update(1) + continue + + alias = get_account_alias(m_session) + hits_in_account = 0 + + try: + ec2 = m_session.client('ec2', region_name=args.region) + regions = [r['RegionName'] for r in ec2.describe_regions()['Regions']] + except ClientError: + regions = [args.region] + + # Inner Progress Bar (Position 1) - One per account + inner_pbar = tqdm(total=len(tag_keys), desc=f"Tags in {acc_id}", position=1, leave=False) + for target_key in tag_keys: + inner_pbar.set_description(f"Checking: {target_key[:20]}") for region in regions: tag_client = m_session.client('resourcegroupstaggingapi', region_name=region) tag_paginator = tag_client.get_paginator('get_resources') - # BUG FIX: Process tags one by one or in small batches. - # AWS TagFilters is an "AND" operation if you provide multiple keys in one filter dict. - # To do an "OR" (Find ANY of these tags), we iterate through each tag key individually. - for target_key in tag_keys: - filter_param = [{'Key': target_key}] - try: - for tag_page in tag_paginator.paginate(TagFilters=filter_param): - for r_mapping in tag_page.get('ResourceTagMappingList', []): - resources_in_account += 1 - arn = r_mapping['ResourceARN'] - # Get the value for the specific key we searched for - val = next((t['Value'] for t in r_mapping['Tags'] if t['Key'] == target_key), "N/A") - - all_results.append({ - "tag_name": target_key, - "tag_value": val, - "account_id": acc_id, - "account_alias": alias, - "region": region, - "arn": arn - }) - except ClientError as e: - if "Throttling" in str(e): - time.sleep(1) # Simple backoff - continue - - acc_elapsed = time.time() - acc_start_time - total_resources_found += resources_in_account - account_metrics.append({ - "account_id": acc_id, - "account_name": acc_name, - "elapsed_seconds": round(acc_elapsed, 2), - "hits_found": resources_in_account - }) - accounts_processed += 1 + try: + for tag_page in tag_paginator.paginate(TagFilters=[{'Key': target_key}]): + for r_mapping in tag_page.get('ResourceTagMappingList', []): + hits_in_account += 1 + val = next((t['Value'] for t in r_mapping['Tags'] if t['Key'] == target_key), "N/A") + all_results.append({ + "tag_name": target_key, + "tag_value": val, + "account_id": acc_id, + "account_alias": alias, + "region": region, + "arn": r_mapping['ResourceARN'] + }) + except ClientError as e: + if "Throttling" in str(e): time.sleep(1) + inner_pbar.update(1) - if args.limit > 0 and accounts_processed >= args.limit: - break + inner_pbar.close() + acc_elapsed = time.time() - acc_start_time + account_metrics.append({ + "account_id": acc_id, "account_name": acc_name, + "elapsed": round(acc_elapsed, 2), "hits": hits_in_account + }) + outer_pbar.update(1) - # Final Output + outer_pbar.close() + + # Final Output and Summary total_elapsed = time.time() - start_time_overall summary = { - "scan_version": __version__, - "total_accounts": accounts_processed, + "version": __version__, + "total_accounts": len(all_accounts), "total_hits": len(all_results), - "total_time": f"{round(total_elapsed, 2)}s", - "account_breakdown": account_metrics + "total_time_seconds": round(total_elapsed, 2), + "accounts": account_metrics } with open('summary_metrics.json', 'w') as f: json.dump(summary, f, indent=4) if all_results: - keys = all_results[0].keys() with open('findings.csv', 'w', newline='') as f: - writer = csv.DictWriter(f, fieldnames=keys) + writer = csv.DictWriter(f, fieldnames=all_results[0].keys()) writer.writeheader() writer.writerows(all_results) - print(f"\n\n[+] Scan Complete. Found {len(all_results)} tag instances across {accounts_processed} accounts.") - + print(f"\n[+] Scan Complete. Found {len(all_results)} hits.") if __name__ == "__main__": main()