Skip to content

Commit

Permalink
add refresher for creds
Browse files Browse the repository at this point in the history
  • Loading branch information
badra001 committed Jan 15, 2026
1 parent 89aad59 commit a998b65
Showing 1 changed file with 51 additions and 31 deletions.
82 changes: 51 additions & 31 deletions local-app/python-tools/cross-organization/tag-checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@
import re
import os
import resource
from datetime import datetime
from datetime import datetime, timezone
from concurrent.futures import ThreadPoolExecutor, as_completed
from botocore.exceptions import ClientError
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session as get_botocore_session
from tqdm import tqdm

__version__ = "1.1.13"
__version__ = "1.1.14"

def get_args():
parser = argparse.ArgumentParser(description=f"AWS Org Tag Scanner v{__version__}")
parser.add_argument("--role-name", required=True, help="Role to assume in member accounts")
parser.add_argument("--region", required=True, help="Management account region (e.g., us-gov-east-1)")
parser.add_argument("--profile", required=True, help="AWS CLI profile for Management Account")
parser.add_argument("--tags-file", required=True, help="CSV file with TagKey, Type, Status, etc.")
parser.add_argument("--tags-file", required=True, help="CSV file with TagKey, Status, Type, etc.")
parser.add_argument("--max-workers", type=int, default=8, help="Max concurrent account scans")
parser.add_argument("--account-regex", help="Regex to filter accounts by alias")
parser.add_argument("--accounts-from", help="File of Account IDs to process")
Expand All @@ -30,11 +32,45 @@ def get_args():
parser.add_argument("--verbose", action="store_true", help="Enable detailed logging")
return parser.parse_args()

def get_session(management_session, account_id, role_name, partition, region_name, verbose):
def create_refreshable_session(profile_name, region_name):
"""
Creates a Boto3 session that automatically refreshes credentials
by re-reading the profile (SSO or otherwise) when they expire.
"""
bc_session = get_botocore_session()

def refresh_credentials():
# This function is called by botocore when creds expire
temp_session = boto3.Session(profile_name=profile_name, region_name=region_name)
creds = temp_session.get_credentials()
return {
"access_key": creds.access_key,
"secret_key": creds.secret_key,
"token": creds.token,
"expiry_time": creds._expiry_time.isoformat() if creds._expiry_time else (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat()
}

session_creds = RefreshableCredentials.create_from_metadata(
metadata=refresh_credentials(),
refresh_using=refresh_credentials,
method="sts-assume-role"
)

bc_session._credentials = session_creds
bc_session.set_config_variable("region", region_name)
return boto3.Session(botocore_session=bc_session)

def get_member_session(management_session, account_id, role_name, partition, region_name, verbose):
"""Assumes role in member account with maximum allowed duration (3600s)."""
sts = management_session.client('sts', region_name=region_name)
role_arn = f"arn:{partition}:iam::{account_id}:role/{role_name}"
try:
response = sts.assume_role(RoleArn=role_arn, RoleSessionName="TagDiscoveryScanner")
# Request 1 hour (3600s) to minimize refresh cycles
response = sts.assume_role(
RoleArn=role_arn,
RoleSessionName="TagDiscoveryScanner",
DurationSeconds=3600
)
c = response['Credentials']
return boto3.Session(
aws_access_key_id=c['AccessKeyId'],
Expand All @@ -48,7 +84,7 @@ def get_session(management_session, account_id, role_name, partition, region_nam

def scan_account(account, management_session, role_name, partition, tag_keys, active_tag_keys, region_name, lane_id, account_regex, verbose, bar_width):
acc_id = account['Id']
m_session = get_session(management_session, acc_id, role_name, partition, region_name, verbose)
m_session = get_member_session(management_session, acc_id, role_name, partition, region_name, verbose)

if not m_session:
return [], acc_id, "N/A", {}, "Auth Fail"
Expand All @@ -69,10 +105,7 @@ def scan_account(account, management_session, role_name, partition, tag_keys, ac
active_regions = [region_name]

acc_start = time.time()
findings = []
global_resources = set()
global_tags_found = set()
regional_metrics = []
findings, global_resources, global_tags_found, regional_metrics = [], set(), set(), []

label = f"{acc_id} {alias}".ljust(bar_width)
pbar = tqdm(total=len(tag_keys), desc=f"Lane {lane_id:02d} | {label}",
Expand Down Expand Up @@ -122,15 +155,12 @@ def scan_account(account, management_session, role_name, partition, tag_keys, ac
r_entry['tags_found_count'] = len(current_tags)
r_entry['tags_not_found_count'] = len(tag_keys) - len(current_tags)
r_entry['elapsed_sec'] = round(r_entry['elapsed_sec'] + r_elapsed, 4)

pbar.update(1)

pbar.close()

metrics = {
"global": {
"hits": len(findings),
"unique_resources": len(global_resources),
"hits": len(findings), "unique_resources": len(global_resources),
"tags_found_count": len(global_tags_found),
"tags_found_list": sorted(list(global_tags_found)),
"tags_found_list_active": sorted(list(global_tags_found.intersection(active_tag_keys))),
Expand All @@ -148,14 +178,14 @@ def main():
start_iso, start_ts = datetime.now().isoformat(), time.time()

try:
session = boto3.Session(profile_name=args.profile, region_name=args.region)
# Use the new refreshable session for the management connection
session = create_refreshable_session(args.profile, args.region)
org = session.client('organizations', region_name=args.region)
partition = session.client('sts', region_name=args.region).get_caller_identity()['Arn'].split(':')[1]

tag_keys, active_tag_keys = [], set()
with open(args.tags_file, mode='r', encoding='utf-8-sig') as f:
# Using your specific headers
reader = csv.DictReader(f, skipinitialspace=True)
reader = csv.DictReader(f, skipinitialspace=True)
for row in reader:
key = row.get('TagKey', '').strip().replace('"', '')
if key:
Expand All @@ -178,15 +208,12 @@ def main():
to_process = [v for k, v in unique_accounts.items() if not target_ids or k in target_ids]
if args.limit > 0: to_process = to_process[:args.limit]

# UI: Fixed width for Label = acc_id(12) + space(1) + alias(max) + buffer(1)
max_label_len = max([12 + 1 + len(a['Name']) for a in to_process]) + 1 if to_process else 40

print(f"\n{'='*85}\nAWS TAG CHECKER v{__version__}\n{'='*85}")
print(f"Profile: {args.profile} | Region: {args.region} | Role: {args.role_name}")
print(f"Tags Read: {len(tag_keys)} ({len(active_tag_keys)} active)")
print(f"Accounts Found (Unique): {len(unique_accounts)}")
print(f"Accounts Targeted: {len(to_process)}")
print(f"Arguments: {vars(args)}")
print(f"{'='*85}\n")

all_findings, account_results = [], []
Expand All @@ -212,24 +239,17 @@ def main():
overall_pbar.close()
print("\n" * (args.max_workers + 1))

# Memory usage in MB (Linux RSS is KB)
mem_mb = round(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024, 2)
total_unique_res = len(set(f['arn'] for f in all_findings))
all_found_keys = set(f['tag_name'] for f in all_findings)

output_summary = {
"summary": {
"version": __version__,
"command_line": cmd_line,
"aws_accounts_scanned": len(account_results),
"tags_read_count": len(tag_keys),
"execution_start": start_iso,
"execution_end": datetime.now().isoformat(),
"elapsed_sec_total": round(time.time() - start_ts, 2),
"max_memory_mb": mem_mb,
"version": __version__, "command_line": cmd_line, "aws_accounts_scanned": len(account_results),
"tags_read_count": len(tag_keys), "execution_start": start_iso, "execution_end": datetime.now().isoformat(),
"elapsed_sec_total": round(time.time() - start_ts, 2), "max_memory_mb": mem_mb,
"total_hits": sum(a['global_metrics']['hits'] for a in account_results),
"total_unique_resources": total_unique_res,
"total_tags_found_count": len(all_found_keys),
"total_unique_resources": total_unique_res, "total_tags_found_count": len(all_found_keys),
"total_tags_not_found_count": len(tag_keys) - len(all_found_keys)
},
"accounts": account_results
Expand Down

0 comments on commit a998b65

Please sign in to comment.