Back to Notes

YOLO Data Conversion and Augmentation

YOLO
Data Preprocessing
Computer Vision
Data Augmentation
2/10/2024
12 min

💡You can run the below code in Kaggle by using the copy and edit buttons. Kaggle Notebook Link

This is a notebook to do the following:

  • Data Format conversion from the YOLO OBB to the YOLO format
  • Data Format conversion from the YOLO format to the YOLO OBB format
  • Data augmentation in YOLO format using the Albumentations library.

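Run the cells in order. Several later cells redefine functions from earlier ones (`convert_obb_to_yolo`, `process_directory`, `main`, `augmentations`, `load_image_and_bboxes`, `process_dataset`), so running cells out of order can silently call the wrong version.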

Data Format Conversion Yolo OBB to Yolo Format

#Importing the libraries
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import math
import shutil

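The function `obb_to_yolo` below converts a single YOLO OBB annotation (a class index followed by the four corner points) into a YOLO annotation. It does this by taking the axis-aligned box that encloses the four corners.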

def obb_to_yolo(class_index, x1, y1, x2, y2, x3, y3, x4, y4):
    """Convert one YOLO OBB annotation into a YOLO (axis-aligned) annotation.

    The YOLO box is the smallest axis-aligned rectangle that encloses the
    four corner points. Coordinates are passed through in whatever units they
    arrive in; no normalization or clipping is applied.

    Args:
        class_index: Class id; truncated to ``int`` in the output.
        x1, y1, x2, y2, x3, y3, x4, y4: The four corners of the oriented box.

    Returns:
        A string ``"<class> <x_center> <y_center> <width> <height>"``.
    """
    xs = (x1, x2, x3, x4)
    ys = (y1, y2, y3, y4)
    left, right = min(xs), max(xs)
    top, bottom = min(ys), max(ys)

    # Centre and size of the enclosing axis-aligned box.
    cx = (left + right) / 2
    cy = (top + bottom) / 2
    return f"{int(class_index)} {cx} {cy} {right - left} {bottom - top}"




!mkdir /kaggle/working/test
def convert_obb_to_yolo(input_file_path, output_directory):
    """Convert a YOLO OBB annotation file into a YOLO annotation file.

    The result is written to ``<output_directory>/<input stem>_yolo.txt``.
    Blank lines (e.g. a trailing newline at the end of the file) are skipped.

    Args:
        input_file_path: Text file with one
            ``class x1 y1 x2 y2 x3 y3 x4 y4`` annotation per line.
        output_directory: Directory for the converted file; created if missing.

    Raises:
        ValueError: If a non-blank line does not contain exactly 9 values.
    """
    file_name_without_extension = os.path.splitext(os.path.basename(input_file_path))[0]

    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(output_directory, exist_ok=True)

    # Output file has the input's name with '_yolo' appended
    output_file_path = os.path.join(output_directory, f"{file_name_without_extension}_yolo.txt")

    with open(input_file_path, 'r') as input_file, open(output_file_path, 'w') as output_file:
        for line_number, line in enumerate(input_file, start=1):
            parts = line.split()
            if not parts:
                continue  # blank line: nothing to convert
            if len(parts) != 9:
                raise ValueError(
                    f"Line {line_number}: each line must contain 9 values: "
                    "class_index, x1, y1, x2, y2, x3, y3, x4, y4."
                )
            class_index, x1, y1, x2, y2, x3, y3, x4, y4 = map(float, parts)
            yolo_line = obb_to_yolo(class_index, x1, y1, x2, y2, x3, y3, x4, y4)
            output_file.write(yolo_line + '\n')
            print(f"Converted line to YOLO format: {yolo_line}")

    print(f"Conversion complete. Output file saved to: {output_file_path}")


output_directory = '/kaggle/working/test'
input_obb_annotations_file = '/kaggle/input/yolo-obb/alexandrite_11_jpg.rf.1052aaef4bac9e1051010f35dabc5e87.txt'
# Writes <output_directory>/<input stem>_yolo.txt
convert_obb_to_yolo(input_obb_annotations_file, output_directory)
!cat /kaggle/input/face-mask-dataset-yolo-format/dataset/images/test/NUZZHAB7IMI6VHQQW44IVOMBHU.txt
1 0.49528571428571433 0.5424346335190742 0.10085714285714287 0.143591941705958

Data Format Conversion Yolo Format to Yolo OBB Format

def yolo_to_obb(class_index, x_center, y_center, width, height):
    """Convert one YOLO annotation into a YOLO OBB annotation line.

    The axis-aligned box is written as its four corners in the order
    top-left, top-right, bottom-right, bottom-left (image coordinates, y
    pointing down).

    Args:
        class_index: Class id. Written as an integer, matching the
            ``<int class> ...`` lines produced by ``obb_to_yolo``; previously
            a float class such as ``1.0`` was emitted.
        x_center, y_center, width, height: YOLO box geometry.

    Returns:
        A string ``"<class> x1 y1 x2 y2 x3 y3 x4 y4"``.
    """
    half_w = width / 2
    half_h = height / 2
    left, right = x_center - half_w, x_center + half_w
    top, bottom = y_center - half_h, y_center + half_h
    corners = (left, top, right, top, right, bottom, left, bottom)
    return f"{int(class_index)} " + " ".join(str(c) for c in corners)

def convert_yolo_to_obb(input_file_path, output_directory):
    """Convert a YOLO annotation file into a YOLO OBB annotation file.

    The result is written to ``<output_directory>/<input stem>_obb.txt``.
    Blank lines (e.g. a trailing newline at the end of the file) are skipped.

    Args:
        input_file_path: Text file with one
            ``class x_center y_center width height`` annotation per line.
        output_directory: Directory for the converted file; created if missing.

    Raises:
        ValueError: If a non-blank line does not contain exactly 5 values.
    """
    file_name_without_extension = os.path.splitext(os.path.basename(input_file_path))[0]

    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(output_directory, exist_ok=True)

    # Output file has the input's name with '_obb' appended
    output_file_path = os.path.join(output_directory, f"{file_name_without_extension}_obb.txt")

    with open(input_file_path, 'r') as input_file, open(output_file_path, 'w') as output_file:
        for line_number, line in enumerate(input_file, start=1):
            parts = line.split()
            if not parts:
                continue  # blank line: nothing to convert
            if len(parts) != 5:
                raise ValueError(
                    f"Line {line_number}: each line must contain 5 values: "
                    "class_index, x_center, y_center, width, height."
                )
            class_index, x_center, y_center, width, height = map(float, parts)
            obb_line = yolo_to_obb(class_index, x_center, y_center, width, height)
            output_file.write(obb_line + '\n')
            print(f"Converted line to OBB format: {obb_line}")

    print(f"Conversion complete. Output file saved to: {output_file_path}")


    
output_directory = '/kaggle/working/test'
input_yolo_annotations_file = '/kaggle/input/face-mask-dataset-yolo-format/dataset/images/test/NUZZHAB7IMI6VHQQW44IVOMBHU.txt'
# Writes <output_directory>/<input stem>_obb.txt
convert_yolo_to_obb(input_yolo_annotations_file, output_directory)
Converted line to OBB format: 1.0 0.4448571428571429 0.47063866266609516 0.5457142857142858 0.47063866266609516 0.5457142857142858 0.6142306043720531 0.4448571428571429 0.6142306043720531
Conversion complete. Output file saved to: /kaggle/working/test/NUZZHAB7IMI6VHQQW44IVOMBHU_obb.txt
!cat /kaggle/working/test/NUZZHAB7IMI6VHQQW44IVOMBHU_obb.txt
1.0 0.4448571428571429 0.47063866266609516 0.5457142857142858 0.47063866266609516 0.5457142857142858 0.6142306043720531 0.4448571428571429 0.6142306043720531

Visualizing the image with the bounding box annotations in obb format

def draw_obb_bounding_box(image, obb_annotation):
    """Draw one YOLO OBB annotation onto ``image`` as a closed blue polygon.

    Args:
        image: Image array drawn on in place by ``cv2.line``; ``shape[:2]``
            is read as (height, width).
        obb_annotation: A ``"class x1 y1 x2 y2 x3 y3 x4 y4"`` line. The corner
            values are always multiplied by the image size, so they are
            expected to be normalized.

    Returns:
        The same ``image`` object. A line without exactly 9 fields is reported
        and not drawn.
    """
    fields = obb_annotation.strip().split()
    if len(fields) != 9:
        print(f"Wrong annotation format: {obb_annotation}")
        return image  # Skip this annotation if the format is wrong

    values = [float(v) for v in fields]  # class id parsed too, but unused
    coords = values[1:]
    img_h, img_w = image.shape[:2]
    corners = [(coords[k] * img_w, coords[k + 1] * img_h) for k in range(0, 8, 2)]

    blue_bgr = (255, 0, 0)
    # Pair every corner with the next one, wrapping the last back to the first.
    for start, end in zip(corners, corners[1:] + corners[:1]):
        cv2.line(image, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), blue_bgr, 2)

    return image

def visualize_obb(image_path, obb_annotations_file):
    """Plot ``image_path`` with every OBB annotation from a label file drawn on it.

    Prints a message and returns ``None`` without plotting when the image
    cannot be read or the label file contains no lines.
    """
    image = cv2.imread(image_path)
    if image is None:
        print(f"Error: Unable to read the image file {image_path}")
        return

    print(f"Image dimensions: {image.shape}")

    with open(obb_annotations_file, 'r') as label_file:
        annotation_lines = label_file.readlines()

    if len(annotation_lines) == 0:
        print(f"No annotations found in file {obb_annotations_file}")
        return

    for annotation_line in annotation_lines:
        image = draw_obb_bounding_box(image, annotation_line)

    # OpenCV loads BGR, matplotlib expects RGB.
    plt.figure(figsize=(10, 10))
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show()
# Image and the OBB label file written by convert_yolo_to_obb above
obb_annotations_file = '/kaggle/working/test/NUZZHAB7IMI6VHQQW44IVOMBHU_obb.txt'
image_path = '/kaggle/input/face-mask-dataset-yolo-format/dataset/images/test/NUZZHAB7IMI6VHQQW44IVOMBHU.jpg'
visualize_obb(image_path, obb_annotations_file)

Image dimensions: (2333, 3500, 3)



Image

Visualizing the yolo bounding box with the image to check if it is correctly generated

def draw_yolo_bounding_boxes(image_path, annotations_path):
    """Plot an image with its YOLO-format boxes outlined in red.

    Label lines that do not have exactly 5 fields are ignored. Prints a
    message and returns ``None`` if the image cannot be read.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f"Error: Unable to read image file {image_path}")
        return

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img_h, img_w, _ = rgb.shape

    plt.figure(figsize=(10, 10))
    plt.imshow(rgb)
    axes = plt.gca()

    with open(annotations_path, 'r') as label_file:
        rows = [line.strip().split() for line in label_file.readlines()]

    for row in rows:
        if len(row) != 5:
            continue
        _cls, xc, yc, bw, bh = map(float, row)
        # YOLO values are normalized; scale them to pixels.
        xc, yc, bw, bh = xc * img_w, yc * img_h, bw * img_w, bh * img_h
        corner = (int(xc - bw / 2), int(yc - bh / 2))
        axes.add_patch(
            patches.Rectangle(corner, bw, bh, linewidth=1, edgecolor='r', facecolor='none')
        )

    plt.axis('off')
    plt.show()

# Example usage: check the YOLO labels produced from the OBB file earlier
annotations_path = '/kaggle/working/test/alexandrite_11_jpg.rf.1052aaef4bac9e1051010f35dabc5e87_yolo.txt'  # Replace with the path to your YOLO annotation file
image_path = '/kaggle/input/yolo-obb/alexandrite_11_jpg.rf.1052aaef4bac9e1051010f35dabc5e87.jpg'  # Replace with the path to your image
draw_yolo_bounding_boxes(image_path, annotations_path)

Pipeline that converts the YOLO OBB annotation files of a whole dataset (train/test/valid directories) into YOLO format, recreating the directory structure in the output directory.

# modifying the convert obb to yolo function for the output path file
def convert_obb_to_yolo(input_obb_file, output_yolo_file, img_width, img_height):
    """Convert one YOLO OBB label file into a YOLO label file at an explicit path.

    Lines that do not contain exactly 9 values are skipped.

    Args:
        input_obb_file: Source ``class x1 y1 x2 y2 x3 y3 x4 y4`` label file.
        output_yolo_file: Destination label file; overwritten if it exists.
        img_width, img_height: Image size in pixels. Kept for compatibility
            with ``process_directory``, which passes them. They are not used,
            because ``obb_to_yolo`` takes only the class and the corners and
            applies no pixel scaling. NOTE(review): this assumes the OBB labels
            are already normalized (Ultralytics OBB convention); confirm for
            your dataset.
    """
    with open(input_obb_file, 'r') as input_file, open(output_yolo_file, 'w') as output_file:
        for line in input_file:
            parts = line.strip().split()
            if len(parts) != 9:
                continue
            class_index, x1, y1, x2, y2, x3, y3, x4, y4 = map(float, parts)
            # obb_to_yolo accepts exactly these nine values; passing the image
            # size as extra arguments raised a TypeError.
            yolo_annotation = obb_to_yolo(class_index, x1, y1, x2, y2, x3, y3, x4, y4)
            output_file.write(yolo_annotation + '\n')
def process_directory(input_dir, output_dir):
    """Mirror ``input_dir`` into ``output_dir``, converting OBB labels to YOLO.

    Every directory under ``input_dir`` is recreated under ``output_dir``.
    A ``.txt`` file is converted only when an image with the same stem and a
    ``.jpg``/``.png``/``.jpeg`` extension sits next to it and OpenCV can read
    that image. The converted file keeps the original file name.
    """
    image_exts = ('.jpg', '.png', '.jpeg')  # Add other image extensions if necessary
    for root, _dirs, filenames in os.walk(input_dir):
        target_dir = os.path.join(output_dir, os.path.relpath(root, input_dir))
        os.makedirs(target_dir, exist_ok=True)

        for filename in filenames:
            if not filename.endswith('.txt'):
                continue
            stem = os.path.splitext(filename)[0]
            candidates = (os.path.join(root, stem + ext) for ext in image_exts)
            img_file = next((path for path in candidates if os.path.exists(path)), None)
            if not img_file:
                continue
            img = cv2.imread(img_file)
            if img is None:
                continue
            height, width = img.shape[:2]
            convert_obb_to_yolo(os.path.join(root, filename),
                                os.path.join(target_dir, filename),
                                width, height)
# Change the split names (or pass `subdirs`) if your dataset uses different ones.
def main(input_root_dir, output_root_dir, subdirs=('train', 'test', 'valid')):
    """Convert every split of a YOLO OBB dataset into YOLO format.

    Each split ``<input_root_dir>/<split>`` is written to
    ``<output_root_dir>/<split>``. Splits that do not exist are skipped.
    """
    for subdir in subdirs:
        input_dir = os.path.join(input_root_dir, subdir)
        if os.path.exists(input_dir):
            # Give each split its own output folder. process_directory maps the
            # split root itself to its output_dir, so passing output_root_dir
            # for every split would merge train/test/valid and let same-named
            # label files overwrite each other.
            process_directory(input_dir, os.path.join(output_root_dir, subdir))


input_root_dir = '/path/to/root/input/directory'  # Replace with your input root directory path
output_root_dir = '/path/to/root/output/directory'  # Replace with your desired output root directory path
main(input_root_dir, output_root_dir)

Pipeline that converts the YOLO annotation files of a whole dataset (train/test/valid directories) into YOLO OBB format, recreating the directory structure in the output directory.

def yolo_to_yolo_obb(class_index, x_center, y_center, width, height):
    """Convert one YOLO annotation into a YOLO OBB annotation line.

    Emits the 9-value layout ``"<class> x1 y1 x2 y2 x3 y3 x4 y4"`` with the
    corners ordered top-left, top-right, bottom-right, bottom-left. This is
    the layout read by ``obb_to_yolo`` and ``draw_obb_bounding_box``, which
    rejects lines without exactly 9 values. The centre and size are therefore
    not repeated after the class, and the class is written as an integer.
    """
    half_w = width / 2
    half_h = height / 2
    left, right = x_center - half_w, x_center + half_w
    top, bottom = y_center - half_h, y_center + half_h

    corners = (left, top, right, top, right, bottom, left, bottom)
    return f"{int(class_index)} " + " ".join(str(c) for c in corners)

def process_directory(input_dir, output_dir):
    """Mirror ``input_dir`` into ``output_dir``, converting YOLO labels to YOLO OBB.

    Every directory under ``input_dir`` is recreated under ``output_dir``.
    A ``.txt`` file is treated as a label file only when an image with the same
    stem (``.jpg``/``.png``/``.jpeg``) sits next to it. The converted file keeps
    the original file name. Label lines without exactly 5 values are skipped.
    """
    image_exts = ('.jpg', '.png', '.jpeg')  # Add other image extensions if necessary
    for root, dirs, files in os.walk(input_dir):
        current_output_dir = os.path.join(output_dir, os.path.relpath(root, input_dir))
        os.makedirs(current_output_dir, exist_ok=True)

        for file in files:
            if not file.endswith('.txt'):
                continue
            img_name = os.path.splitext(file)[0]
            # Only convert label files that have a matching image next to them.
            if not any(os.path.exists(os.path.join(root, img_name + ext)) for ext in image_exts):
                continue

            input_yolo_file = os.path.join(root, file)
            output_yolo_obb_file = os.path.join(current_output_dir, file)
            # YOLO coordinates are normalized, so the image size is not needed.
            # convert_yolo_to_obb takes (input file, output *directory*) and no
            # size arguments; calling it with four arguments raised TypeError.
            # So convert line by line into the exact output path here.
            with open(input_yolo_file, 'r') as src, open(output_yolo_obb_file, 'w') as dst:
                for line in src:
                    parts = line.split()
                    if len(parts) != 5:
                        continue
                    class_index, x_center, y_center, width, height = map(float, parts)
                    dst.write(yolo_to_obb(class_index, x_center, y_center, width, height) + '\n')

# Change the split names (or pass `subdirs`) if your dataset uses different ones.
def main(input_root_dir, output_root_dir, subdirs=('train', 'test', 'valid')):
    """Convert every split of a YOLO dataset into YOLO OBB format.

    Each split ``<input_root_dir>/<split>`` is written to
    ``<output_root_dir>/<split>``. Splits that do not exist are skipped.
    """
    for subdir in subdirs:
        input_dir = os.path.join(input_root_dir, subdir)
        if os.path.exists(input_dir):
            # Give each split its own output folder. process_directory maps the
            # split root itself to its output_dir, so passing output_root_dir
            # for every split would merge train/test/valid and let same-named
            # label files overwrite each other.
            process_directory(input_dir, os.path.join(output_root_dir, subdir))


input_root_dir = '/path/to/root/input/directory'  # Replace with your input root directory path
output_root_dir = '/path/to/root/output/directory'  # Replace with your desired output root directory path
main(input_root_dir, output_root_dir)

Augmentation

  • The code below augments a single image together with its bounding boxes and plots the original next to the augmented result.
#Importing the libraries
import os
import shutil
from tqdm import tqdm
from tqdm import tqdm
import os
import numpy as np
import cv2
from albumentations.pytorch import ToTensorV2
import albumentations as A
# Define the augmentations
# Boxes go in and come out in YOLO format. After the transforms, boxes smaller
# than min_area (pixels) or with less than min_visibility of their area left
# inside the image are dropped, together with their entries in `labels`.
augmentations = A.Compose(
    [
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        A.RGBShift(r_shift_limit=30, g_shift_limit=30, b_shift_limit=30, p=0.5),
        A.RandomResizedCrop(height=416, width=416, p=1),
        # Converts the HWC numpy image to a CHW torch tensor. It does NOT
        # rescale: a uint8 input stays uint8 in [0, 255].
        ToTensorV2(p=1.0),
    ],
    bbox_params=A.BboxParams(format='yolo', min_area=1024, min_visibility=0.3, label_fields=['labels'])
)
#Loading the image and the bounding boxes
def load_image_and_bboxes(image_path, bboxes_path):
    """Load an image (converted to RGB) and its YOLO label file.

    Args:
        image_path: Path to the image file.
        bboxes_path: YOLO label file with ``class xc yc w h`` per line.

    Returns:
        ``(image, bboxes, labels)``: the RGB image array, a list of
        ``[x_center, y_center, width, height]`` and the matching int class ids.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image.
        ValueError: If a non-blank label line does not have exactly 5 values.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None; fail with a clear
        # message instead of the opaque error cv2.cvtColor raises on None.
        raise FileNotFoundError(f"Unable to read image file: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    bboxes = []
    labels = []
    with open(bboxes_path, 'r') as file:
        for line in file:
            elements = line.split()
            if not elements:
                continue  # blank line (e.g. trailing newline) carries no box
            class_label, x_center, y_center, width, height = map(float, elements)
            labels.append(int(class_label))
            bboxes.append([x_center, y_center, width, height])

    return image, bboxes, labels

# Apply augmentations
def augment(image, bboxes, labels):
    """Run the module-level ``augmentations`` pipeline and unpack its result.

    Returns:
        ``(image, bboxes, labels)`` as produced by the pipeline.
    """
    result = augmentations(image=image, bboxes=bboxes, labels=labels)
    return tuple(result[key] for key in ('image', 'bboxes', 'labels'))

# Input image path and the bounding box path.
# The image is loaded and augmented once, right before plotting; doing it here
# as well only did redundant I/O and drew a random augmentation that was
# immediately overwritten.
image_path = '/kaggle/input/face-mask-dataset-yolo-format/dataset/images/train/-1x-1.jpg'
bboxes_path = '/kaggle/input/face-mask-dataset-yolo-format/dataset/images/train/-1x-1.txt'

# Function to convert YOLO bboxes to matplotlib format
def yolo_to_mpl_bbox(bbox, image_size):
    """Convert a normalized YOLO box into a matplotlib ``Rectangle`` spec.

    Args:
        bbox: ``(x_center, y_center, width, height)`` normalized to the image.
        image_size: ``(width, height)`` of the image in pixels.

    Returns:
        ``(x_min, y_min, width, height)`` in pixels.
    """
    (xc, yc, bw, bh), (img_w, img_h) = bbox, image_size
    return (
        (xc - bw / 2) * img_w,
        (yc - bh / 2) * img_h,
        bw * img_w,
        bh * img_h,
    )

# Function to plot an image and draw the bounding boxes
def plot_image_with_bboxes(image_np, bboxes, ax):
    """Show ``image_np`` on ``ax`` and outline each YOLO box in red."""
    ax.imshow(image_np)
    # shape[1::-1] is (width, height) for an (H, W[, C]) array.
    image_size = image_np.shape[1::-1]
    for box in bboxes:
        x0, y0, w, h = yolo_to_mpl_bbox(box, image_size)
        ax.add_patch(
            patches.Rectangle((x0, y0), w, h, linewidth=2, edgecolor='r', facecolor='none')
        )

# Load and augment the image and bboxes
image, bboxes, labels = load_image_and_bboxes(image_path, bboxes_path)
augmented_image, augmented_bboxes, augmented_labels = augment(image, bboxes, labels)

# Convert the augmented CHW tensor back to an HWC uint8 array for matplotlib.
# ToTensorV2 does not rescale, so a uint8 image comes back as a uint8 tensor
# in [0, 255]. Multiplying that by 255 wraps around in uint8 and shows a
# corrupted, near colour-negative image. Only a float tensor is assumed to be
# in [0, 1] and rescaled (NOTE: assumption, e.g. if ToFloat/Normalize is added).
if augmented_image.is_floating_point():
    augmented_image_np = augmented_image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
else:
    augmented_image_np = augmented_image.permute(1, 2, 0).numpy()

# Create the subplot for original and augmented images
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

# Plot the original image with bounding boxes
plot_image_with_bboxes(image, bboxes, ax1)
ax1.set_title('Original Image')

# Plot the augmented image with bounding boxes
plot_image_with_bboxes(augmented_image_np, augmented_bboxes, ax2)
ax2.set_title('Augmented Image')

# Display the images
plt.show()

The code below augments a whole YOLO-format dataset (train, test and valid splits). For each image it writes one augmented copy plus its label file.

# Define the augmentations 
# Boxes go in and come out in YOLO format. After the transforms, boxes smaller
# than min_area (pixels) or with less than min_visibility of their area left
# inside the image are dropped, together with their entries in `labels`.
augmentations = A.Compose(
    [
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        A.RGBShift(r_shift_limit=30, g_shift_limit=30, b_shift_limit=30, p=0.5),
        A.RandomResizedCrop(height=416, width=416, p=1),
        # Converts the HWC numpy image to a CHW torch tensor. It does NOT
        # rescale: a uint8 input stays uint8 in [0, 255].
        ToTensorV2(p=1.0),
    ],
    bbox_params=A.BboxParams(format='yolo', min_area=1024, min_visibility=0.3, label_fields=['labels'])
)
# Function to process and augment a dataset directory (train, test or val)
def process_dataset(dataset_dir, output_dir):
    """Augment every image of one dataset split once and save image + labels.

    Each ``.jpg``/``.png``/``.jpeg`` in ``dataset_dir`` goes through the
    module-level ``augmentations`` pipeline once. The result is written to
    ``output_dir`` under the same file name, with a YOLO label file of the same
    stem next to it.

    Args:
        dataset_dir: Directory holding images and same-stem ``.txt`` labels.
        output_dir: Existing directory that receives the augmented files.
    """
    image_files = [f for f in os.listdir(dataset_dir) if os.path.splitext(f)[1].lower() in ['.jpg', '.png', '.jpeg']]

    for image_filename in tqdm(image_files, desc=f"Processing {output_dir}"):
        image_path = os.path.join(dataset_dir, image_filename)
        annotation_filename = os.path.splitext(image_filename)[0] + '.txt'
        annotation_path = os.path.join(dataset_dir, annotation_filename)

        image, bboxes, labels = load_image_and_bboxes(image_path, annotation_path)
        if image is None:
            # The tolerant loader variant returns None for a missing label file
            # or unreadable image (and prints why); skip instead of crashing.
            continue

        augmented = augmentations(image=image, bboxes=bboxes, labels=labels)
        augmented_image = augmented['image']
        augmented_bboxes = augmented['bboxes']
        augmented_labels = augmented['labels']

        # ToTensorV2 yields a CHW tensor; convert back to an HWC uint8 array.
        if isinstance(augmented_image, np.ndarray):
            image_to_save = augmented_image
        else:
            image_to_save = augmented_image.numpy().astype(np.uint8)
            if image_to_save.shape[0] == 3:  # channels-first (C, H, W)
                image_to_save = image_to_save.transpose(1, 2, 0)

        output_image_path = os.path.join(output_dir, image_filename)
        output_annotation_path = os.path.join(output_dir, annotation_filename)

        # The loader converted the image to RGB, but cv2.imwrite expects BGR;
        # without this conversion the saved image has red and blue swapped.
        image_bgr = cv2.cvtColor(np.ascontiguousarray(image_to_save), cv2.COLOR_RGB2BGR)
        cv2.imwrite(output_image_path, image_bgr)

        # Albumentations returns only box coordinates (x_center, y_center, w, h).
        # The class ids come back separately in 'labels', filtered in step with
        # the boxes. Using bbox[0] as the class wrote int(x_center), i.e. 0.
        with open(output_annotation_path, 'w') as f:
            for class_id, bbox in zip(augmented_labels, augmented_bboxes):
                bbox_str = ' '.join(map(str, bbox))
                f.write(f"{int(class_id)} {bbox_str}\n")


# Augment every split into its own output folder.
dataset_root = "/kaggle/input/face-mask-dataset-yolo-format/dataset/images"
output_root = "/kaggle/working/augmentation"
for dataset_type in ('train', 'test', 'valid'):
    dataset_dir = os.path.join(dataset_root, dataset_type)
    output_dir = os.path.join(output_root, dataset_type)
    os.makedirs(output_dir, exist_ok=True)
    process_dataset(dataset_dir, output_dir)

# Number of files (images + label files) written for the training split.
dir_path = "/kaggle/working/augmentation/train"
print(sum(os.path.isfile(os.path.join(dir_path, entry)) for entry in os.listdir(dir_path)))

from IPython.display import Image

# NOTE: a notebook cell only renders its last expression, so keep each
# Image(...) call below at the end of its own cell.
# Augmented copy of images19.jpg, written by the split-augmentation loop above.
Image(filename=f"/kaggle/working/augmentation/valid/images19.jpg", width=600)
from IPython.display import Image

# The original images19.jpg from the input dataset, for comparison.

Image(filename=f"/kaggle/input/face-mask-dataset-yolo-format/dataset/images/valid/images19.jpg", width=600)
# WARNING: deletes EVERYTHING under /kaggle/working, including the augmented
# dataset written to /kaggle/working/augmentation and the converted labels in
# /kaggle/working/test. Run only to reset the working directory (e.g. to remove
# folders created by mistake).
!rm -r /kaggle/working/*

The code below applies each augmentation in a list separately, with probability 1, to every image of the dataset splits. Each original image therefore yields one augmented copy per augmentation type. Every output image and its label file get the augmentation's tag appended to the file name, so the versions can be told apart: for example, `img.jpg` becomes `img_vflip.jpg`, `img_hflip.jpg`, and so on. The transform types and parameters are the same as in the pipeline above; only the probabilities differ.

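Make sure you are running the `load_image_and_bboxes` implemented below, not the earlier one; otherwise this will throw an error. The version below returns `None` for images whose label file is missing or that cannot be read, and skips malformed label lines instead of raising.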

def load_image_and_bboxes(image_path, annotation_path):
    """Load an RGB image and its YOLO boxes, tolerating missing inputs.

    Returns:
        ``(image, bboxes, labels)``. If the label file does not exist or the
        image cannot be read, a message is printed and ``(None, [], [])`` is
        returned. Label lines that do not have exactly 5 values are skipped.
    """
    if not os.path.isfile(annotation_path):
        print(f"Annotation file does not exist: {annotation_path}")
        return None, [], []

    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f"Failed to read the image: {image_path}")
        return None, [], []
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    with open(annotation_path, "r") as label_file:
        rows = [line.strip().split() for line in label_file]
    parsed = [tuple(map(float, row)) for row in rows if len(row) == 5]

    labels = [int(row[0]) for row in parsed]
    bboxes = [list(row[1:]) for row in parsed]  # YOLO format: xc, yc, w, h
    return rgb, bboxes, labels

# Individual augmentations, each with p=1 so every output file really has it.
vertical_flip = A.VerticalFlip(p=1)
random_bright_contrast = A.RandomBrightnessContrast(p=1)
horizontal_flip = A.HorizontalFlip(p=1)
shift_scale_rotate = A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0)
rgb_shift = A.RGBShift(r_shift_limit=30, g_shift_limit=30, b_shift_limit=30, p=1)
random_resized_crop = A.RandomResizedCrop(height=416, width=416, p=1)


def _with_yolo_bboxes(transform):
    """Wrap a single transform in a Compose that understands YOLO boxes.

    Called on its own, a transform does not convert boxes from YOLO format,
    so flips, crops and shifts are applied to the coordinates as if they were
    in a different box format. It also does not filter ``labels`` along with
    the boxes. ``A.Compose`` with ``BboxParams(format='yolo')`` converts YOLO
    boxes in and out and keeps ``labels`` aligned with the boxes it keeps.

    ``min_visibility`` drops boxes mostly cut off by crops/shifts. No
    ``min_area`` is set, so small but intact objects are not silently left
    unlabeled by a simple flip or colour change.
    """
    return A.Compose(
        [transform],
        bbox_params=A.BboxParams(format='yolo', min_visibility=0.3, label_fields=['labels']),
    )


# (file-name tag, pipeline) pairs; the tag is appended to output file names.
augmentations_list = [
    ('vflip', _with_yolo_bboxes(vertical_flip)),
    ('bright_contrast', _with_yolo_bboxes(random_bright_contrast)),
    ('hflip', _with_yolo_bboxes(horizontal_flip)),
    ('shift_scale_rotate', _with_yolo_bboxes(shift_scale_rotate)),
    ('rgb_shift', _with_yolo_bboxes(rgb_shift)),
    ('resized_crop', _with_yolo_bboxes(random_resized_crop)),
]
# Function to process and augment a dataset directory (train, test or val)
def process_dataset(dataset_dir, output_dir):
    """Write one augmented copy of every image per entry in ``augmentations_list``.

    For ``img.jpg`` and tag ``vflip`` this writes ``img_vflip.jpg`` and
    ``img_vflip.txt`` (YOLO labels) into ``output_dir``. Images whose label
    file is missing or whose image cannot be read are skipped.
    """
    image_files = [f for f in os.listdir(dataset_dir) if os.path.splitext(f)[1].lower() in ['.jpg', '.jpeg', '.png']]

    for image_filename in tqdm(image_files, desc=f"Processing {output_dir}"):
        filename_without_ext, ext = os.path.splitext(image_filename)
        image_path = os.path.join(dataset_dir, image_filename)
        annotation_path = os.path.join(dataset_dir, filename_without_ext + '.txt')

        image, bboxes, labels = load_image_and_bboxes(image_path, annotation_path)
        if image is None:
            continue  # missing label file or unreadable image (already reported)

        for aug_name, aug in augmentations_list:
            augmented = aug(image=image, bboxes=bboxes, labels=labels)
            output_image_path = os.path.join(output_dir, f"{filename_without_ext}_{aug_name}{ext}")
            output_annotation_path = os.path.join(output_dir, f"{filename_without_ext}_{aug_name}.txt")
            # Pass the per-box class ids explicitly; they are kept in step with
            # the boxes the augmentation kept.
            save_augmented_data(
                augmented['image'],
                augmented['bboxes'],
                output_image_path,
                output_annotation_path,
                labels=augmented.get('labels', labels),
            )


# Save the augmented image and bounding boxes
def save_augmented_data(image, bboxes, image_path, annotation_path, labels=None):
    """Write an augmented RGB image and its YOLO label file.

    Args:
        image: HWC RGB ``np.ndarray``, or a tensor with ``.numpy()``; a
            3-channel CHW layout is transposed to HWC.
        bboxes: Boxes as ``(x_center, y_center, width, height)``.
        image_path: Destination of the image (format from its extension).
        annotation_path: Destination of the YOLO label file.
        labels: Class id per box, aligned with ``bboxes``. When omitted, each
            box is assumed to carry its class id as its first element,
            followed by its coordinates.

    An empty ``bboxes`` list yields an empty label file (an image with no
    objects); previously this raised IndexError.
    """
    if not isinstance(image, np.ndarray):
        image = image.numpy().astype(np.uint8)
        if image.shape[0] == 3:  # channels-first (C, H, W) -> (H, W, C)
            image = image.transpose(1, 2, 0)

    # Images are kept in RGB; reverse the channels because OpenCV writes BGR.
    cv2.imwrite(image_path, np.ascontiguousarray(image[:, :, ::-1]))

    if labels is None:
        rows = [(bbox[0], bbox[1:]) for bbox in bboxes]
    else:
        # One class per box. Using bboxes[0][0] (an x_center) for every
        # box made all classes 0.
        rows = list(zip(labels, bboxes))

    with open(annotation_path, 'w') as file:
        for class_id, coords in rows:
            bbox_str = ' '.join(map(str, coords))
            file.write(f"{int(class_id)} {bbox_str}\n")
# Define the paths and augment every split.
dataset_root = "/kaggle/input/face-mask-dataset-yolo-format/dataset/images"
output_root = "/kaggle/working/augmentation"
for dataset_type in ('train', 'test', 'valid'):
    dataset_dir = os.path.join(dataset_root, dataset_type)
    output_dir = os.path.join(output_root, dataset_type)
    os.makedirs(output_dir, exist_ok=True)
    process_dataset(dataset_dir, output_dir)
print("--------------------Done---------------------")

# Number of files (images + label files) written for the training split.
dir_path = "/kaggle/working/augmentation/train"
print(sum(os.path.isfile(os.path.join(dir_path, entry)) for entry in os.listdir(dir_path)))