|
| 1 | +import argparse |
| 2 | +import getpass |
| 3 | +import json |
| 4 | +from typing import Optional, Dict |
| 5 | + |
| 6 | +import boto3 |
| 7 | +from botocore.client import BaseClient |
| 8 | + |
| 9 | + |
def get_s3_client(profile: Optional[str], endpoint_url: Optional[str], access_key: Optional[str], secret_key: Optional[str]):
    """Build a boto3 S3 client from a named profile or from explicit credentials.

    Args:
        profile: Name of a configured AWS profile. When given, it takes
            precedence over ``access_key``/``secret_key``.
        endpoint_url: Custom S3 endpoint. Required together with the key pair
            when no profile is given; optional alongside a profile.
        access_key: Access key ID, used only when no profile is given.
        secret_key: Secret access key, used only when no profile is given.

    Returns:
        The S3 client created by boto3.

    Raises:
        ValueError: If neither a profile nor the complete
            endpoint/access-key/secret-key triple is supplied.
    """
    if profile:
        session = boto3.Session(profile_name=profile)
        # Forward the endpoint so a profile can also target a custom or
        # S3-compatible endpoint; None (or "") keeps boto3's default endpoint.
        return session.client("s3", endpoint_url=endpoint_url or None)

    if endpoint_url and access_key and secret_key:
        return boto3.client(
            "s3",
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

    raise ValueError("You must provide either a profile or endpoint URL and access/secret key.")
| 18 | + |
def set_tags(s3, bucket: str, key: str, tags: Dict[str, str]) -> bool:
    """Replace the tag set of one object, reporting failure instead of raising.

    Args:
        s3: Client object exposing ``put_object_tagging``.
        bucket: Bucket that holds the object.
        key: Key of the object to tag.
        tags: Mapping of tag name to tag value.

    Returns:
        True when the tagging call completed, False when it raised (the
        error is printed and swallowed so a caller can keep going).
    """
    tagging = {
        'TagSet': [{'Key': name, 'Value': value} for name, value in tags.items()]
    }
    try:
        s3.put_object_tagging(Bucket=bucket, Key=key, Tagging=tagging)
    except Exception as e:
        # Best-effort: one failing object must not abort a bulk run.
        print(f"Error tagging object {key}: {e}")
        return False
    return True
| 27 | + |
def tag_files_recursive(s3: BaseClient, bucket: str, key: str, tags: Dict[str, str]):
    """Tag every object listed under the ``key`` prefix of ``bucket``.

    Objects are listed page by page with the ``list_objects_v2`` paginator;
    a progress line is printed for each page and each listed object is passed
    to ``set_tags``.

    Args:
        s3: S3 client used for listing and tagging.
        bucket: Bucket to scan.
        key: Key prefix passed to the listing call.
        tags: Mapping of tag name to tag value applied to each object.

    Returns:
        Number of objects for which ``set_tags`` reported success.
    """
    pages = s3.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=key)
    total_tagged = 0
    for page_number, page in enumerate(pages):
        # Only entries that carry a 'Size' field are considered.
        objects = []
        for obj in page.get('Contents', []):
            if 'Size' in obj:
                objects.append(obj['Key'])
        print(f"Page {page_number + 1}, Size: {len(objects)}")
        total_tagged += sum(
            1 for object_key in objects if set_tags(s3, bucket, object_key, tags)
        )
    return total_tagged
| 39 | + |
def _parse_tags(raw: str) -> Dict[str, str]:
    """Decode ``raw`` as a flat JSON object of string values, else raise ValueError."""
    try:
        tags = json.loads(raw)
    except json.JSONDecodeError as err:
        raise ValueError(f"--tagging is not valid JSON: {err}") from err
    if not isinstance(tags, dict):
        raise ValueError("--tagging must be a JSON object, e.g. '{\"key\":\"value\"}'")
    # JSON object keys are always strings; only the values need checking.
    non_string = [name for name, value in tags.items() if not isinstance(value, str)]
    if non_string:
        raise ValueError(
            "--tagging values must be strings; offending keys: " + ", ".join(non_string)
        )
    return tags


def main():
    """Command-line entry point: tag every object under a prefix in a bucket.

    Parses the command line, validates the ``--tagging`` JSON, prompts for the
    endpoint URL and credentials when no profile is given, then tags the
    objects and prints how many were tagged. Invalid arguments terminate the
    program through ``parser.error`` (usage message, exit status 2).
    """
    parser = argparse.ArgumentParser(
        description="Tag objects in S3 bucket"
    )

    parser.add_argument("--profile", type=str, help="AWS profile name")
    parser.add_argument("--endpoint-url", type=str, help="AWS endpoint URL")
    parser.add_argument("--access-key", type=str, help="AWS access key")
    parser.add_argument("--secret-key", type=str, help="AWS secret key")
    parser.add_argument("--bucket", type=str, required=True, help="S3 bucket name")
    parser.add_argument("--prefix", type=str, required=True, help="prefix to tag objects under. E.g. foo/bar/")
    parser.add_argument("--tagging", type=str, required=True, help="tags in format of flat JSON map. E.g. '{\"key\":\"value\"}'")

    args = parser.parse_args()

    # Validate the tags first so malformed input fails fast, before the user
    # is prompted for credentials and before any request is made.
    try:
        tags = _parse_tags(args.tagging)
    except ValueError as err:
        parser.error(str(err))

    # Ask for endpoint URL, access and secret key if not defined nor profile
    if not args.profile:
        if not args.endpoint_url:
            args.endpoint_url = input("Enter endpoint URL: ")
        if not args.access_key:
            args.access_key = input("Enter access key: ")
        if not args.secret_key:
            # getpass keeps the secret from being echoed to the terminal.
            args.secret_key = getpass.getpass("Enter secret key: ")

    # Tag objects
    print("Starting tagging...")
    try:
        s3 = get_s3_client(
            profile=args.profile,
            endpoint_url=args.endpoint_url,
            access_key=args.access_key,
            secret_key=args.secret_key,
        )
    except ValueError as err:
        # E.g. a prompt was answered with an empty string.
        parser.error(str(err))
    tagged_files = tag_files_recursive(s3=s3, bucket=args.bucket, key=args.prefix, tags=tags)
    print(f"Tagged {tagged_files} files")


if __name__ == "__main__":
    main()
0 commit comments