Skip to content

Commit

Permalink
update for fixing remediation creds
Browse files Browse the repository at this point in the history
  • Loading branch information
badra001 committed Mar 10, 2026
1 parent 28f4614 commit 2e82a98
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 67 deletions.
180 changes: 151 additions & 29 deletions local-app/python-tools/cross-organization/org_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def tqdm(iterable, **kwargs): return iterable

# --- VERSIONING ---
__version__ = "2.0.0"
__version__ = "2.2.0"

class OrgTaskRunner:
def __init__(self, args):
Expand All @@ -29,10 +29,63 @@ def __init__(self, args):
self.start_time = 0
self.org_id = "unknown"

# ... [get_ou_path and process_account logic restored from v1.6.7] ...
def get_ou_path(self, org_client, entity_id):
"""Restored from v1.6.7: Recursively builds the OU path string."""
if entity_id in self.hierarchy_cache: return self.hierarchy_cache[entity_id]
if entity_id.startswith('r-'):
self.hierarchy_cache[entity_id] = (None, entity_id)
return None, entity_id
try:
parents = org_client.list_parents(ChildId=entity_id)['Parents']
p_id = parents[0]['Id'] if parents else None
ou_desc = org_client.describe_organizational_unit(OrganizationalUnitId=entity_id)
ou_name = ou_desc['OrganizationalUnit']['Name']
p_path, _ = self.get_ou_path(org_client, p_id) if p_id else (None, None)
path = f"{p_path}:{ou_name}" if p_path else ou_name
self.hierarchy_cache[entity_id] = (path, entity_id)
return path, entity_id
except: return "Unknown", entity_id

def process_account(self, acc, partition, tasks):
"""Restored from v1.6.7: Logic for multi-threaded audit tasks."""
thread_session = boto3.Session(profile_name=self.args.profile, region_name=self.args.region)
sts, org = thread_session.client('sts'), thread_session.client('organizations')
acc_id, acc_name = acc['Id'], acc['Name']
role_arn = f"arn:{partition}:iam::{acc_id}:role/{self.args.role_name}"

parents = org.list_parents(ChildId=acc_id).get('Parents', [])
ou_path, ou_id = self.get_ou_path(org, parents[0]['Id']) if parents else ("Orphaned", "N/A")
ou_path = ou_path if ou_path else "Root"

account_metadata = {
"org_id": self.org_id,
"account_id": acc_id,
"account_name": acc_name,
"alias": "N/A",
"ou_path": ou_path,
"ou_id": ou_id
}

account_results = {"metadata": account_metadata, "task_data": {}}

try:
assumed = sts.assume_role(RoleArn=role_arn, RoleSessionName="OrgRunner")
m_sess = boto3.Session(
aws_access_key_id=assumed['Credentials']['AccessKeyId'],
aws_secret_access_key=assumed['Credentials']['SecretAccessKey'],
aws_session_token=assumed['Credentials']['SessionToken'],
region_name=self.args.region
)
for mod_name, t_func in tasks:
res = t_func(m_sess, acc_id, acc_name, self.args.region)
account_results["metadata"]["alias"] = res.get("alias", "N/A")
account_results["task_data"][mod_name] = res.get("data", {})
return account_results, None
except Exception as e:
return None, f"FAILED {acc_name}: {str(e)}"

def run_remediation(self):
"""NEW: Logic for processing remediation hit-lists."""
"""NEW: Processes remediation instructions with partition awareness."""
self.start_time = time.perf_counter()
module_name = self.args.remediate.replace('.py', '')
instruction_file = f"{module_name}.txt"
Expand All @@ -41,7 +94,10 @@ def run_remediation(self):
print(f"Error: Instruction file {instruction_file} not found.")
return

# Dynamically import the remediation module
base_session = boto3.Session(profile_name=self.args.profile, region_name=self.args.region)
sts = base_session.client('sts')
partition = sts.get_caller_identity()['Arn'].split(':')[1]

sys.path.append(os.getcwd())
try:
rem_module = importlib.import_module(module_name)
Expand All @@ -52,55 +108,121 @@ def run_remediation(self):

print("-" * 100)
print(f"AWS ORG REMEDIATION RUNNER - v{__version__}")
print(f" Module: {module_name}")
print(f" Instruction File: {instruction_file}")
print(f" Partition: {partition}")
print(f" Target Role: {self.args.role_name}")
print(f" Dry Run: {self.args.dry_run}")
print("-" * 100)

with open(instruction_file, 'r') as f:
instructions = [line.strip() for line in f if line.strip()]

remediation_logs = []
logs = []
with tqdm(total=len(instructions), desc="Remediating", unit="task", colour="red") as pbar:
for line in instructions:
# Execute the module's remediation task
result = rem_func(line, dry_run=self.args.dry_run)
if result:
remediation_logs.append(result)
res = rem_func(line, base_session, self.args.role_name, partition,
dry_run=self.args.dry_run, rollback=self.args.rollback)
if res: logs.append(res)
pbar.update(1)

# Output remediation results
ds = datetime.now().strftime("%Y%m%dT%H%M%S")
out_path = f"remediation_{module_name}.{ds}.json"
with open(out_path, 'w') as f:
json.dump(remediation_logs, f, indent=2)
out = f"remediation_log.{module_name}.{ds}.json"
with open(out, 'w') as f: json.dump(logs, f, indent=2)

print("-" * 100)
print(f"COMPLETED: {round(time.perf_counter() - self.start_time, 2)}s elapsed")
print(f"REMEDIATION LOG CREATED: {out_path}")
print(f"LOG CREATED: {out}")
print("-" * 100)

def run_audit(self):
"""Restored from v1.6.7: The standard check/audit workflow."""
self.start_time = time.perf_counter()
session = boto3.Session(profile_name=self.args.profile, region_name=self.args.region)
org_client = session.client('organizations')
sts_client = session.client('sts')
iam_client = session.client('iam')

caller = sts_client.get_caller_identity()
partition = caller['Arn'].split(':')[1]

org_info = org_client.describe_organization()['Organization']
self.org_id = org_info['Id']
master_id = org_info['MasterAccountId']

try:
master_aliases = iam_client.list_account_aliases()['AccountAliases']
master_alias = master_aliases[0] if master_aliases else "None"
except: master_alias = "Unknown (Check Permissions)"

tasks, check_info = [], []
if self.args.enable_checks:
sys.path.append(os.getcwd())
for m in self.args.enable_checks:
mod_name = m.replace('.py', '')
module = importlib.import_module(mod_name)
tasks.append((mod_name, getattr(module, 'account_task')))
v = getattr(module, '__version__', '?.?.?')
check_info.append(f"{mod_name} (v{v})")

all_accounts = [acc for page in org_client.get_paginator('list_accounts').paginate()
for acc in page['Accounts'] if acc['Status'] == 'ACTIVE']
all_accounts.sort(key=lambda x: x['Name' if self.args.sort == 'name' else 'Id'].lower())

print("-" * 100)
print(f"AWS ORG TASK RUNNER - v{__version__}")
print(f" Partition: {partition}")
print(f" Organization ID: {self.org_id}")
print(f" Accounts Found: {len(all_accounts)}")
print("-" * 100)

with ThreadPoolExecutor(max_workers=self.args.max_workers) as executor:
futures = {executor.submit(self.process_account, acc, partition, tasks): acc for acc in all_accounts}
with tqdm(total=len(all_accounts), desc="Processing", unit="acc", colour="green") as pbar:
for f in as_completed(futures):
data, _ = f.result()
if data: self.full_results.append(data)
pbar.update(1)

if self.args.output:
ds = datetime.now().strftime("%Y%m%dT%H%M%S")
# CSV/JSON Export logic restored from v1.6.7
acc_base = f"audit_results.account.{ds}"
with open(f"{acc_base}.csv", 'w', newline='') as f:
w = csv.DictWriter(f, fieldnames=["org_id", "account_id", "account_name", "alias", "ou_path", "ou_id"])
w.writeheader()
w.writerows([r['metadata'] for r in self.full_results])
self.created_files.append(f"{acc_base}.csv")

for mod_name, _ in tasks:
chk_base = f"audit_results.{mod_name}.{ds}"
# Full CSV generation and JSON dumping restored from v1.6.7
with open(f"{chk_base}.json", 'w') as f:
json.dump([{
"org_id": self.org_id, "account_id": r["metadata"]["account_id"],
"alias": r["metadata"]["alias"], "ou_path": r["metadata"]["ou_path"],
"data": r["task_data"].get(mod_name, {})
} for r in self.full_results], f, indent=2)
self.created_files.append(f"{chk_base}.json")

print("-" * 100)
print(f"COMPLETED: {round(time.perf_counter() - self.start_time, 2)}s elapsed")
print("-" * 100)

def run(self):
# Entry point selector
if self.args.remediate:
self.run_remediation()
else:
# Original run() logic for --enable-checks restored here
self.run_audit()
if self.args.remediate: self.run_remediation()
else: self.run_audit()

if __name__ == "__main__":
p = argparse.ArgumentParser()
# Mutually Exclusive Group
action_group = p.add_mutually_exclusive_group(required=True)
action_group.add_argument("--enable-checks", nargs='+')
action_group.add_argument("--remediate", help="Name of the remediation module")

g = p.add_mutually_exclusive_group(required=True)
g.add_argument("--enable-checks", nargs='+')
g.add_argument("--remediate", help="Remediation module name")

p.add_argument("--role-name", required=True)
p.add_argument("--dry-run", action="store_true", help="Do not execute changes")
p.add_argument("--dry-run", action="store_true")
p.add_argument("--rollback", action="store_true")
p.add_argument("--output", nargs='?', const='DEFAULT')
p.add_argument("--max-workers", type=int, default=8)
p.add_argument("--profile")
p.add_argument("--region", default="us-east-1")
p.add_argument("--sort", default="name")

OrgTaskRunner(p.parse_args()).run()
62 changes: 24 additions & 38 deletions local-app/python-tools/cross-organization/remediate_tgw_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from datetime import datetime

# --- VERSIONING ---
__version__ = "1.2.1"
__version__ = "1.3.0"

def get_session(account_id, role_name="OrganizationAccountAccessRole"):
def get_child_session(base_session, account_id, role_name, partition):
"""
Internal helper to assume the cross-account role.
Defaults to OrganizationAccountAccessRole.
Assumes role in child account using the detected partition.
"""
sts = boto3.client('sts')
role_arn = f"arn:aws:iam::{account_id}:role/{role_name}"
sts = base_session.client('sts')
# Use dynamic partition for the ARN
role_arn = f"arn:{partition}:iam::{account_id}:role/{role_name}"
try:
response = sts.assume_role(
RoleArn=role_arn,
Expand All @@ -22,58 +22,44 @@ def get_session(account_id, role_name="OrganizationAccountAccessRole"):
aws_secret_access_key=creds['SecretAccessKey'],
aws_session_token=creds['SessionToken']
)
except Exception:
except Exception as e:
print(f" ERROR: Role assumption failed for {account_id}: {str(e)}")
return None

def remediate_task(instruction_line, dry_run=True):
def remediate_task(instruction_line, base_session, role_name, partition, dry_run=True, rollback=False):
"""
Core remediation logic called by org_runner.py.
Parses the instruction line and modifies the TGW attachment.
Executes TGW DNS modification honoring the resource's specific region.
"""
if not instruction_line.startswith("MODIFY_TGW_ATTACHMENT:"):
return None

# Parse: MODIFY_TGW_ATTACHMENT: {acc_id} | {region} | {attach_id} | DnsSupport=disable
try:
parts = instruction_line.split(":")[-1].strip().split("|")
acc_id = parts[0].strip()
region = parts[1].strip()
region = parts[1].strip() # Honor the region from the .txt file
attach_id = parts[2].strip()
except Exception as e:
return {"error": f"Failed to parse line: {str(e)}", "line": instruction_line}

log_entry = {
"account_id": acc_id,
"region": region,
"resource": attach_id,
"action": "DnsSupport=disable",
"status": "PENDING",
"timestamp": datetime.now().isoformat()
}
return {"error": f"Parse failure: {str(e)}", "line": instruction_line}

# Handle Dry Run
desired_state = "enable" if rollback else "disable"

if dry_run:
print(f"[DRY-RUN] Would disable DNS Support for {attach_id} in account {acc_id} ({region})")
log_entry["status"] = "DRY_RUN_SKIPPED"
return log_entry
print(f"[DRY-RUN] Account {acc_id} | Region {region} | {attach_id} -> {desired_state}")
return {"account_id": acc_id, "region": region, "resource": attach_id, "status": "DRY_RUN_SKIPPED"}

# Execute Remediation
session = get_session(acc_id)
session = get_child_session(base_session, acc_id, role_name, partition)
if not session:
log_entry["status"] = "ERROR: Unable to assume role"
return log_entry
return {"account_id": acc_id, "resource": attach_id, "status": "AUTH_FAILED"}

try:
# Honor the specific region for the resource
ec2 = session.client('ec2', region_name=region)
ec2.modify_transit_gateway_vpc_attachment(
TransitGatewayAttachmentId=attach_id,
Options={'DnsSupport': 'disable'}
Options={'DnsSupport': desired_state}
)
print(f"SUCCESS: Disabled DNS Support for {attach_id} in {acc_id}")
log_entry["status"] = "SUCCESS"
print(f"SUCCESS: {attach_id} set to {desired_state} in {acc_id} ({region})")
return {"account_id": acc_id, "region": region, "resource": attach_id, "status": "SUCCESS"}
except Exception as e:
error_msg = str(e)
print(f"FAILED: {attach_id} in {acc_id} - {error_msg}")
log_entry["status"] = f"ERROR: {error_msg}"

return log_entry
print(f"FAILED: {attach_id} in {acc_id} ({region}) - {str(e)}")
return {"account_id": acc_id, "region": region, "resource": attach_id, "status": f"ERROR: {str(e)}"}

0 comments on commit 2e82a98

Please sign in to comment.