Files
gobot/board-vision/cnn_model/fileSetUtils.py
2024-08-20 01:15:43 +02:00

75 lines
2.4 KiB (Stored with Git LFS)
Python

import os
import random
import glob
from io import BytesIO
from zipfile import ZipFile
import cv2
import numpy as np
import pandas as pd
def decompress(f: str) -> pd.DataFrame:
    """Load a labelled image archive into a DataFrame.

    Every file entry of the zip archive is read and handed to
    ``cv2.imdecode`` as a colour image. The label comes from the first
    character of the entry's base name: ``1`` when it is ``"P"``, ``-1``
    otherwise. The entries are shuffled with a fixed seed, so the row order
    is the same on every call for the same archive.

    Args:
        f: Path of the zip archive to read.

    Returns:
        A DataFrame with one row per file entry and the columns ``images``
        (whatever ``cv2.imdecode`` returned for the entry), ``label``
        (``1`` or ``-1``) and ``name`` (base name of the entry).
    """
    with ZipFile(f, 'r') as z:
        # Directory entries (names ending in "/") hold no image bytes and
        # have an empty base name, so they are skipped instead of being
        # decoded and labelled.
        filelist = [info.filename for info in z.infolist() if not info.is_dir()]
        # NOTE: this reseeds the module-level RNG; it is kept so that the
        # shuffle order (and any later use of ``random`` by callers) stays
        # exactly as before.
        random.seed(42)
        random.shuffle(filelist)
        images = []
        label = []
        name = []
        for entry in filelist:
            base = os.path.basename(entry)
            raw = np.frombuffer(z.read(entry), dtype=np.uint8)
            images.append(cv2.imdecode(raw, flags=cv2.IMREAD_COLOR))
            # "P" prefix marks a positive sample; everything else is negative.
            label.append(1 if base.startswith("P") else -1)
            name.append(base)
    return pd.DataFrame({'images': images, 'label': label, 'name': name})
def compress(f: str, input_folder: str):
    """Pack the PNG files of a labelled folder into a zip archive.

    ``input_folder`` must contain the sub-folders ``negative`` and
    ``positive``. All ``*.png`` files found in them (recursively) are stored
    in the archive under flat names: ``N_<counter>.png`` for negatives and
    ``P_<counter>.png`` for positives. The counter is shared, so positives
    continue numbering where the negatives stopped.

    Args:
        f: Path of the zip archive to create (an existing file is
            overwritten).
        input_folder: Folder holding the ``negative`` and ``positive``
            sub-folders.

    Raises:
        FileNotFoundError: If the ``negative`` or ``positive`` sub-folder
            does not exist; the missing path is the exception argument.
    """
    root = os.path.abspath(input_folder)
    negative_folder = os.path.join(root, "negative")
    positive_folder = os.path.join(root, "positive")
    # Validate both sub-folders before the output archive is created.
    for folder in (negative_folder, positive_folder):
        if not os.path.isdir(folder):
            raise FileNotFoundError(folder)
    with ZipFile(f, 'w') as z:
        counter = 0
        for prefix, folder in (("N", negative_folder), ("P", positive_folder)):
            # Escape the folder part so characters such as "[" or "*" in the
            # path are matched literally instead of acting as glob syntax.
            pattern = os.path.join(glob.escape(folder), '**', '*.png')
            # Sorted so the file -> archive-name mapping does not depend on
            # the filesystem's directory listing order.
            for file in sorted(glob.glob(pattern, recursive=True)):
                z.write(file, f"{prefix}_{counter}.png")
                counter += 1
if __name__ == "__main__":
    # Command-line entry point: build an archive from a folder, or with
    # --peek print a summary of an existing archive.
    import argparse

    cli = argparse.ArgumentParser(description='Compress a folder')
    cli.add_argument('input_folder', type=str, help='Folder to compress')
    cli.add_argument('output_file', type=str, help='Output file name')
    cli.add_argument("--peek", action="store_true", help="Peek the content of the compressed file")
    options = cli.parse_args()

    if not options.peek:
        compress(options.output_file, options.input_folder)
    else:
        pd.set_option('display.max_rows', 500)
        frame = decompress(options.output_file)
        # First rows of the loaded archive, then a blank separator line.
        print(frame.head(15))
        print()
        # Zero-padded counts: negatives, positives, total.
        negatives = frame.query("label == -1")
        positives = frame.query("label == 1")
        print("- {:07d}".format(negatives.shape[0]))
        print("+ {:07d}".format(positives.shape[0]))
        print(" {:07d}".format(frame.shape[0]))