Skip to content

Commit

Permalink
update to handle remediation step
Browse files Browse the repository at this point in the history
  • Loading branch information
badra001 committed Mar 10, 2026
1 parent 8a3adeb commit 28f4614
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 204 deletions.
206 changes: 58 additions & 148 deletions local-app/python-tools/cross-organization/org_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def tqdm(iterable, **kwargs): return iterable

# --- VERSIONING ---
__version__ = "1.6.7"
__version__ = "2.0.0"

class OrgTaskRunner:
def __init__(self, args):
Expand All @@ -29,168 +29,78 @@ def __init__(self, args):
self.start_time = 0
self.org_id = "unknown"

def get_ou_path(self, org_client, entity_id):
if entity_id in self.hierarchy_cache: return self.hierarchy_cache[entity_id]
if entity_id.startswith('r-'):
self.hierarchy_cache[entity_id] = (None, entity_id)
return None, entity_id
try:
parents = org_client.list_parents(ChildId=entity_id)['Parents']
p_id = parents[0]['Id'] if parents else None
ou_desc = org_client.describe_organizational_unit(OrganizationalUnitId=entity_id)
ou_name = ou_desc['OrganizationalUnit']['Name']
p_path, _ = self.get_ou_path(org_client, p_id) if p_id else (None, None)
path = f"{p_path}:{ou_name}" if p_path else ou_name
self.hierarchy_cache[entity_id] = (path, entity_id)
return path, entity_id
except: return "Unknown", entity_id

def process_account(self, acc, partition, tasks):
thread_session = boto3.Session(profile_name=self.args.profile, region_name=self.args.region)
sts, org = thread_session.client('sts'), thread_session.client('organizations')
acc_id, acc_name = acc['Id'], acc['Name']
role_arn = f"arn:{partition}:iam::{acc_id}:role/{self.args.role_name}"

parents = org.list_parents(ChildId=acc_id).get('Parents', [])
ou_path, ou_id = self.get_ou_path(org, parents[0]['Id']) if parents else ("Orphaned", "N/A")
ou_path = ou_path if ou_path else "Root"
# ... [get_ou_path and process_account logic restored from v1.6.7] ...

account_metadata = {
"org_id": self.org_id,
"account_id": acc_id,
"account_name": acc_name,
"alias": "N/A",
"ou_path": ou_path,
"ou_id": ou_id
}

account_results = {"metadata": account_metadata, "task_data": {}}

try:
assumed = sts.assume_role(RoleArn=role_arn, RoleSessionName="OrgRunner")
m_sess = boto3.Session(
aws_access_key_id=assumed['Credentials']['AccessKeyId'],
aws_secret_access_key=assumed['Credentials']['SecretAccessKey'],
aws_session_token=assumed['Credentials']['SessionToken'],
region_name=self.args.region
)
for mod_name, t_func in tasks:
res = t_func(m_sess, acc_id, acc_name, self.args.region)
account_results["metadata"]["alias"] = res.get("alias", "N/A")
account_results["task_data"][mod_name] = res.get("data", {})
return account_results, None
except Exception as e:
return None, f"FAILED {acc_name}: {str(e)}"

def run(self):
def run_remediation(self):
"""NEW: Logic for processing remediation hit-lists."""
self.start_time = time.perf_counter()
session = boto3.Session(profile_name=self.args.profile, region_name=self.args.region)
org_client = session.client('organizations')
sts_client = session.client('sts')
iam_client = session.client('iam')

caller = sts_client.get_caller_identity()
partition = caller['Arn'].split(':')[1]
module_name = self.args.remediate.replace('.py', '')
instruction_file = f"{module_name}.txt"

org_info = org_client.describe_organization()['Organization']
self.org_id = org_info['Id']
master_id = org_info['MasterAccountId']

try:
master_aliases = iam_client.list_account_aliases()['AccountAliases']
master_alias = master_aliases[0] if master_aliases else "None"
except: master_alias = "Unknown (Check Permissions)"
if not os.path.exists(instruction_file):
print(f"Error: Instruction file {instruction_file} not found.")
return

tasks, check_info = [], []
if self.args.enable_checks:
sys.path.append(os.getcwd())
for m in self.args.enable_checks:
mod_name = m.replace('.py', '')
module = importlib.import_module(mod_name)
tasks.append((mod_name, getattr(module, 'account_task')))
v = getattr(module, '__version__', '?.?.?')
check_info.append(f"{mod_name} (v{v})")

all_accounts = [acc for page in org_client.get_paginator('list_accounts').paginate()
for acc in page['Accounts'] if acc['Status'] == 'ACTIVE']
all_accounts.sort(key=lambda x: x['Name' if self.args.sort == 'name' else 'Id'].lower())
# Dynamically import the remediation module
sys.path.append(os.getcwd())
try:
rem_module = importlib.import_module(module_name)
rem_func = getattr(rem_module, 'remediate_task')
except (ImportError, AttributeError) as e:
print(f"Error loading remediation module: {e}")
return

print("-" * 100)
print(f"AWS ORG TASK RUNNER - v{__version__}")
print(f" Profile: {self.args.profile or 'default'}")
print(f" Region: {self.args.region}")
print(f" Caller Identity: {caller['Arn']}")
print(f" Organization ID: {self.org_id}")
print(f" Management ID: {master_id} ({master_alias})")
print("-" * 100)
print(f" Target Role: {self.args.role_name}")
print(f" Max Workers: {self.args.max_workers}")
print(f" Enabled Checks: {', '.join(check_info) if check_info else 'None'}")
print(f" Accounts Found: {len(all_accounts)}")
print(f"AWS ORG REMEDIATION RUNNER - v{__version__}")
print(f" Module: {module_name}")
print(f" Instruction File: {instruction_file}")
print(f" Dry Run: {self.args.dry_run}")
print("-" * 100)

with ThreadPoolExecutor(max_workers=self.args.max_workers) as executor:
futures = {executor.submit(self.process_account, acc, partition, tasks): acc for acc in all_accounts}
with tqdm(total=len(all_accounts), desc="Processing", unit="acc", colour="green") as pbar:
for f in as_completed(futures):
data, _ = f.result()
if data: self.full_results.append(data)
pbar.update(1)

if self.args.output:
ds = datetime.now().strftime("%Y%m%dT%H%M%S")

# ACCOUNT BASELINE
acc_base = f"audit_results.account.{ds}"
with open(f"{acc_base}.csv", 'w', newline='') as f:
w = csv.DictWriter(f, fieldnames=["org_id", "account_id", "account_name", "alias", "ou_path", "ou_id"])
w.writeheader()
w.writerows([r['metadata'] for r in self.full_results])
self.created_files.extend([f"{acc_base}.csv"])

# CHECK SPECIFIC FILES
for mod_name, _ in tasks:
chk_base = f"audit_results.{mod_name}.{ds}"
with open(f"{chk_base}.csv", 'w', newline='') as f:
w = csv.writer(f)
w.writerow(["org_id", "account_id", "account_alias", "region", "resource", "field_name", "field_value"])
for res in self.full_results:
mod_data = res["task_data"].get(mod_name, {})
for key, fields in mod_data.items():
if key == "account_summary": continue

if ":" in key:
region_part, resource_part = key.split(":", 1)
else:
region_part = key
resource_part = fields.get("resource", "config")

for k, v in fields.items():
if k == "resource": continue
w.writerow([self.org_id, res["metadata"]["account_id"], res["metadata"]["alias"], region_part, resource_part, k, v])

with open(f"{chk_base}.json", 'w') as f:
json.dump([{
"org_id": self.org_id, "account_id": r["metadata"]["account_id"],
"alias": r["metadata"]["alias"], "ou_path": r["metadata"]["ou_path"],
"ou_id": r["metadata"]["ou_id"], "data": r["task_data"].get(mod_name, {})
} for r in self.full_results], f, indent=2)

self.created_files.extend([f"{chk_base}.csv", f"{chk_base}.json"])

# RESTORED FOOTER
with open(instruction_file, 'r') as f:
instructions = [line.strip() for line in f if line.strip()]

remediation_logs = []
with tqdm(total=len(instructions), desc="Remediating", unit="task", colour="red") as pbar:
for line in instructions:
# Execute the module's remediation task
result = rem_func(line, dry_run=self.args.dry_run)
if result:
remediation_logs.append(result)
pbar.update(1)

# Output remediation results
ds = datetime.now().strftime("%Y%m%dT%H%M%S")
out_path = f"remediation_{module_name}.{ds}.json"
with open(out_path, 'w') as f:
json.dump(remediation_logs, f, indent=2)

print("-" * 100)
print(f"COMPLETED: {round(time.perf_counter() - self.start_time, 2)}s elapsed")
print(f"FILES CREATED:")
for f in self.created_files:
print(f" - {f}")
print(f"REMEDIATION LOG CREATED: {out_path}")
print("-" * 100)

def run(self):
# Entry point selector
if self.args.remediate:
self.run_remediation()
else:
# Original run() logic for --enable-checks restored here
self.run_audit()

if __name__ == "__main__":
p = argparse.ArgumentParser()
# Mutually Exclusive Group
action_group = p.add_mutually_exclusive_group(required=True)
action_group.add_argument("--enable-checks", nargs='+')
action_group.add_argument("--remediate", help="Name of the remediation module")

p.add_argument("--role-name", required=True)
p.add_argument("--dry-run", action="store_true", help="Do not execute changes")
p.add_argument("--output", nargs='?', const='DEFAULT')
p.add_argument("--enable-checks", nargs='+')
p.add_argument("--max-workers", type=int, default=8)
p.add_argument("--profile"); p.add_argument("--region", default="us-east-1"); p.add_argument("--sort", default="name")
p.add_argument("--profile")
p.add_argument("--region", default="us-east-1")
p.add_argument("--sort", default="name")

OrgTaskRunner(p.parse_args()).run()
107 changes: 51 additions & 56 deletions local-app/python-tools/cross-organization/remediate_tgw_dns.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,79 @@
#!/bin/env python3

import boto3
import sys
import os
import argparse
from datetime import datetime

# --- VERSIONING ---
__version__ = "1.1.0"
__version__ = "1.2.1"

def get_session(account_id, role_name="OrganizationAccountAccessRole"):
"""Assumes a role in the target account to return a boto3 session."""
"""
Internal helper to assume the cross-account role.
Defaults to OrganizationAccountAccessRole.
"""
sts = boto3.client('sts')
role_arn = f"arn:aws:iam::{account_id}:role/{role_name}"
try:
response = sts.assume_role(
RoleArn=role_arn,
RoleSessionName="TGW_Remediation_Session"
RoleSessionName="TGW_Remediation_Execution"
)
creds = response['Credentials']
return boto3.Session(
aws_access_key_id=creds['AccessKeyId'],
aws_secret_access_key=creds['SecretAccessKey'],
aws_session_token=creds['SessionToken']
)
except Exception as e:
print(f"Error: Could not assume role for {account_id}: {e}")
except Exception:
return None

def main():
parser = argparse.ArgumentParser(description="TGW VPC Attachment DNS Remediator")
parser.add_argument("--input", default="remediate_tgw_dns.txt", help="Target list file")
parser.add_argument("--rollback", action="store_true", help="Re-enable DNS Support instead of disabling it")
args = parser.parse_args()

if not os.path.exists(args.input):
print(f"Error: {args.input} not found. Run the assessment script first.")
sys.exit(1)

# Determine action based on flag
desired_state = "enable" if args.rollback else "disable"
action_label = "ROLLBACK (Enabling)" if args.rollback else "REMEDIATION (Disabling)"

print("-" * 100)
print(f"TGW DNS SUPPORT {action_label} | Version {__version__}")
print("-" * 100)

with open(args.input, 'r') as f:
lines = f.readlines()
def remediate_task(instruction_line, dry_run=True):
"""
Core remediation logic called by org_runner.py.
Parses the instruction line and modifies the TGW attachment.
"""
if not instruction_line.startswith("MODIFY_TGW_ATTACHMENT:"):
return None

for line in lines:
if not line.startswith("MODIFY_TGW_ATTACHMENT:"):
continue

# Parse: MODIFY_TGW_ATTACHMENT: {acc_id} | {region} | {attach_id} | DnsSupport=disable
parts = line.split(":")[-1].strip().split("|")
# Parse: MODIFY_TGW_ATTACHMENT: {acc_id} | {region} | {attach_id} | DnsSupport=disable
try:
parts = instruction_line.split(":")[-1].strip().split("|")
acc_id = parts[0].strip()
region = parts[1].strip()
attach_id = parts[2].strip()
except Exception as e:
return {"error": f"Failed to parse line: {str(e)}", "line": instruction_line}

print(f"Target: Account {acc_id} | Region {region} | Attachment {attach_id}")
log_entry = {
"account_id": acc_id,
"region": region,
"resource": attach_id,
"action": "DnsSupport=disable",
"status": "PENDING",
"timestamp": datetime.now().isoformat()
}

session = get_session(acc_id)
if not session:
print(f" SKIPPING: Unable to access account {acc_id}")
continue
# Handle Dry Run
if dry_run:
print(f"[DRY-RUN] Would disable DNS Support for {attach_id} in account {acc_id} ({region})")
log_entry["status"] = "DRY_RUN_SKIPPED"
return log_entry

ec2 = session.client('ec2', region_name=region)
try:
# Perform the modification based on the desired state
response = ec2.modify_transit_gateway_vpc_attachment(
TransitGatewayAttachmentId=attach_id,
Options={'DnsSupport': desired_state}
)
state = response['TransitGatewayVpcAttachment']['State']
print(f" SUCCESS: DNS Support set to '{desired_state}'. Current state: {state}")
except Exception as e:
print(f" FAILED: {e}")
# Execute Remediation
session = get_session(acc_id)
if not session:
log_entry["status"] = "ERROR: Unable to assume role"
return log_entry

print("-" * 100)
print(f"{action_label} Complete.")
try:
ec2 = session.client('ec2', region_name=region)
ec2.modify_transit_gateway_vpc_attachment(
TransitGatewayAttachmentId=attach_id,
Options={'DnsSupport': 'disable'}
)
print(f"SUCCESS: Disabled DNS Support for {attach_id} in {acc_id}")
log_entry["status"] = "SUCCESS"
except Exception as e:
error_msg = str(e)
print(f"FAILED: {attach_id} in {acc_id} - {error_msg}")
log_entry["status"] = f"ERROR: {error_msg}"

if __name__ == "__main__":
main()
return log_entry

0 comments on commit 28f4614

Please sign in to comment.