From a998b655b4de9c9b4cdfb97ab0a1581c7755cdd8 Mon Sep 17 00:00:00 2001 From: badra001 Date: Thu, 15 Jan 2026 12:30:49 -0500 Subject: [PATCH] add refresher for creds --- .../cross-organization/tag-checker.py | 82 ++++++++++++------- 1 file changed, 51 insertions(+), 31 deletions(-) diff --git a/local-app/python-tools/cross-organization/tag-checker.py b/local-app/python-tools/cross-organization/tag-checker.py index 939afc2b..620888da 100755 --- a/local-app/python-tools/cross-organization/tag-checker.py +++ b/local-app/python-tools/cross-organization/tag-checker.py @@ -9,19 +9,21 @@ import re import os import resource -from datetime import datetime +from datetime import datetime, timezone from concurrent.futures import ThreadPoolExecutor, as_completed from botocore.exceptions import ClientError +from botocore.credentials import RefreshableCredentials +from botocore.session import get_session as get_botocore_session from tqdm import tqdm -__version__ = "1.1.13" +__version__ = "1.1.14" def get_args(): parser = argparse.ArgumentParser(description=f"AWS Org Tag Scanner v{__version__}") parser.add_argument("--role-name", required=True, help="Role to assume in member accounts") parser.add_argument("--region", required=True, help="Management account region (e.g., us-gov-east-1)") parser.add_argument("--profile", required=True, help="AWS CLI profile for Management Account") - parser.add_argument("--tags-file", required=True, help="CSV file with TagKey, Type, Status, etc.") + parser.add_argument("--tags-file", required=True, help="CSV file with TagKey, Status, Type, etc.") parser.add_argument("--max-workers", type=int, default=8, help="Max concurrent account scans") parser.add_argument("--account-regex", help="Regex to filter accounts by alias") parser.add_argument("--accounts-from", help="File of Account IDs to process") @@ -30,11 +32,45 @@ def get_args(): parser.add_argument("--verbose", action="store_true", help="Enable detailed logging") return parser.parse_args() -def get_session(management_session, account_id, role_name, partition, region_name, verbose): +def create_refreshable_session(profile_name, region_name): + """ + Creates a Boto3 session that automatically refreshes credentials + by re-reading the profile (SSO or otherwise) when they expire. + """ + bc_session = get_botocore_session() + + def refresh_credentials(): + # This function is called by botocore when creds expire + temp_session = boto3.Session(profile_name=profile_name, region_name=region_name) + creds = temp_session.get_credentials() + return { + "access_key": creds.access_key, + "secret_key": creds.secret_key, + "token": creds.token, + "expiry_time": creds._expiry_time.isoformat() if creds._expiry_time else (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + } + + session_creds = RefreshableCredentials.create_from_metadata( + metadata=refresh_credentials(), + refresh_using=refresh_credentials, + method="sts-assume-role" + ) + + bc_session._credentials = session_creds + bc_session.set_config_variable("region", region_name) + return boto3.Session(botocore_session=bc_session) + +def get_member_session(management_session, account_id, role_name, partition, region_name, verbose): + """Assumes role in member account with maximum allowed duration (3600s).""" sts = management_session.client('sts', region_name=region_name) role_arn = f"arn:{partition}:iam::{account_id}:role/{role_name}" try: - response = sts.assume_role(RoleArn=role_arn, RoleSessionName="TagDiscoveryScanner") + # Request 1 hour (3600s) to minimize refresh cycles + response = sts.assume_role( + RoleArn=role_arn, + RoleSessionName="TagDiscoveryScanner", + DurationSeconds=3600 + ) c = response['Credentials'] return boto3.Session( aws_access_key_id=c['AccessKeyId'], @@ -48,7 +84,7 @@ def get_session(management_session, account_id, role_name, partition, region_nam def scan_account(account, management_session, role_name, partition, tag_keys, active_tag_keys, region_name, lane_id, account_regex, verbose, bar_width): acc_id = account['Id'] - m_session = get_session(management_session, acc_id, role_name, partition, region_name, verbose) + m_session = get_member_session(management_session, acc_id, role_name, partition, region_name, verbose) if not m_session: return [], acc_id, "N/A", {}, "Auth Fail" @@ -69,10 +105,7 @@ def scan_account(account, management_session, role_name, partition, tag_keys, ac active_regions = [region_name] acc_start = time.time() - findings = [] - global_resources = set() - global_tags_found = set() - regional_metrics = [] + findings, global_resources, global_tags_found, regional_metrics = [], set(), set(), [] label = f"{acc_id} {alias}".ljust(bar_width) pbar = tqdm(total=len(tag_keys), desc=f"Lane {lane_id:02d} | {label}", @@ -122,15 +155,12 @@ def scan_account(account, management_session, role_name, partition, tag_keys, ac r_entry['tags_found_count'] = len(current_tags) r_entry['tags_not_found_count'] = len(tag_keys) - len(current_tags) r_entry['elapsed_sec'] = round(r_entry['elapsed_sec'] + r_elapsed, 4) - pbar.update(1) pbar.close() - metrics = { "global": { - "hits": len(findings), - "unique_resources": len(global_resources), + "hits": len(findings), "unique_resources": len(global_resources), "tags_found_count": len(global_tags_found), "tags_found_list": sorted(list(global_tags_found)), "tags_found_list_active": sorted(list(global_tags_found.intersection(active_tag_keys))), @@ -148,14 +178,14 @@ def main(): start_iso, start_ts = datetime.now().isoformat(), time.time() try: - session = boto3.Session(profile_name=args.profile, region_name=args.region) + # Use the new refreshable session for the management connection + session = create_refreshable_session(args.profile, args.region) org = session.client('organizations', region_name=args.region) partition = session.client('sts', region_name=args.region).get_caller_identity()['Arn'].split(':')[1] tag_keys, active_tag_keys = [], set() with open(args.tags_file, mode='r', encoding='utf-8-sig') as f: - # Using your specific headers - reader = csv.DictReader(f, skipinitialspace=True) + reader = csv.DictReader(f, skipinitialspace=True) for row in reader: key = row.get('TagKey', '').strip().replace('"', '') if key: @@ -178,15 +208,12 @@ def main(): to_process = [v for k, v in unique_accounts.items() if not target_ids or k in target_ids] if args.limit > 0: to_process = to_process[:args.limit] - # UI: Fixed width for Label = acc_id(12) + space(1) + alias(max) + buffer(1) max_label_len = max([12 + 1 + len(a['Name']) for a in to_process]) + 1 if to_process else 40 print(f"\n{'='*85}\nAWS TAG CHECKER v{__version__}\n{'='*85}") print(f"Profile: {args.profile} | Region: {args.region} | Role: {args.role_name}") print(f"Tags Read: {len(tag_keys)} ({len(active_tag_keys)} active)") - print(f"Accounts Found (Unique): {len(unique_accounts)}") print(f"Accounts Targeted: {len(to_process)}") - print(f"Arguments: {vars(args)}") print(f"{'='*85}\n") all_findings, account_results = [], [] @@ -212,24 +239,17 @@ def main(): overall_pbar.close() print("\n" * (args.max_workers + 1)) - # Memory usage in MB (Linux RSS is KB) mem_mb = round(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024, 2) total_unique_res = len(set(f['arn'] for f in all_findings)) all_found_keys = set(f['tag_name'] for f in all_findings) output_summary = { "summary": { - "version": __version__, - "command_line": cmd_line, - "aws_accounts_scanned": len(account_results), - "tags_read_count": len(tag_keys), - "execution_start": start_iso, - "execution_end": datetime.now().isoformat(), - "elapsed_sec_total": round(time.time() - start_ts, 2), - "max_memory_mb": mem_mb, + "version": __version__, "command_line": cmd_line, "aws_accounts_scanned": len(account_results), + "tags_read_count": len(tag_keys), "execution_start": start_iso, "execution_end": datetime.now().isoformat(), + "elapsed_sec_total": round(time.time() - start_ts, 2), "max_memory_mb": mem_mb, "total_hits": sum(a['global_metrics']['hits'] for a in account_results), - "total_unique_resources": total_unique_res, - "total_tags_found_count": len(all_found_keys), + "total_unique_resources": total_unique_res, "total_tags_found_count": len(all_found_keys), "total_tags_not_found_count": len(tag_keys) - len(all_found_keys) }, "accounts": account_results