# ------------------------------------------------------------------------------
# Copyright (c) ETRI. All rights reserved.
# Licensed under the BSD 3-Clause License.
# This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project.
# You can refer to details of AIR project at https://aiforrobots.github.io
# Written by Youngwoo Yoon (youngwoo@etri.re.kr)
# ------------------------------------------------------------------------------

import glob
import json
import re
import _pickle as pickle

import cv2
import matplotlib
from webvtt import WebVTT

from config import my_config


###############################################################################
# SKELETON

def draw_skeleton_on_image(img, skeleton, thickness=15):
    """Return a copy of `img` with the upper-body skeleton drawn on it.

    `skeleton` is a flat OpenPose keypoint list [x0, y0, c0, x1, y1, c1, ...].
    Returns `img` unchanged (not a copy) when `skeleton` is empty/None.
    """
    if not skeleton:
        return img

    new_img = img.copy()
    for pair in SkeletonWrapper.skeleton_line_pairs:
        pt1 = (int(skeleton[pair[0] * 3]), int(skeleton[pair[0] * 3 + 1]))
        pt2 = (int(skeleton[pair[1] * 3]), int(skeleton[pair[1] * 3 + 1]))
        # Skip limbs with an undetected endpoint (OpenPose reports zeros for
        # missing joints).  The original test compared pt1's x with pt2's *y*
        # (`pt1[0] == 0 or pt2[1] == 0`), letting some missing joints through.
        if pt1[0] == 0 or pt2[0] == 0:
            continue
        # matplotlib gives RGBA in [0, 1]; scale to [0, 255] and reverse
        # RGB -> BGR for OpenCV.
        rgb = [v * 255 for v in matplotlib.colors.to_rgba(pair[2])][:3]
        cv2.line(new_img, pt1, pt2, color=rgb[::-1], thickness=thickness)

    return new_img


def is_list_empty(my_list):
    """Return True if `my_list` is an (arbitrarily nested) empty list.

    A non-list value is never "empty"; a list is empty only if all of its
    elements are themselves empty lists (recursively).
    """
    return all(map(is_list_empty, my_list)) if isinstance(my_list, list) else False


def get_closest_skeleton(frame, selected_body):
    """Find the person in `frame` whose skeleton is closest to `selected_body`.

    Only the x/y coordinates of the first 8 (upper-body) joints are compared.
    Returns None when no person is close enough (tracking failure).
    """
    # indices of x and y coordinates of the first 8 joints (upper body)
    diff_idx = [i * 3 for i in range(8)] + [i * 3 + 1 for i in range(8)]

    min_diff = 10000000
    tracked_person = None
    for person in frame:  # people
        body = get_skeleton_from_frame(person)
        if body is None:  # person entry without keypoints — skip
            continue

        # mean absolute coordinate difference over jointly-detected joints
        diff = 0
        n_diff = 0
        for i in diff_idx:
            if body[i] > 0 and selected_body[i] > 0:
                diff += abs(body[i] - selected_body[i])
                n_diff += 1

        # NOTE(fix): only candidates sharing at least one detected joint are
        # considered; previously a person with no comparable joints kept
        # diff == 0 and always won as the "closest" match.
        if n_diff > 0:
            diff /= n_diff
            if diff < min_diff:
                min_diff = diff
                tracked_person = person

    # acceptance threshold derived from the selected body's own scale
    # (neck-to-nose height and shoulder width)
    base_distance = max(abs(selected_body[0 * 3 + 1] - selected_body[1 * 3 + 1]) * 3,
                        abs(selected_body[2 * 3] - selected_body[5 * 3]) * 2)
    if tracked_person and min_diff > base_distance:  # tracking failed
        tracked_person = None

    return tracked_person


def get_skeleton_from_frame(frame):
    """Return the 2D pose keypoint list from an OpenPose person entry.

    Supports both the newer ('pose_keypoints_2d') and older
    ('pose_keypoints') OpenPose output formats; returns None if neither
    key is present.
    """
    if 'pose_keypoints_2d' in frame:
        return frame['pose_keypoints_2d']
    elif 'pose_keypoints' in frame:
        return frame['pose_keypoints']
    else:
        return None


class SkeletonWrapper:
    """Loads per-frame OpenPose skeleton JSON files for a video and caches
    them as a single pickle for faster subsequent loads."""

    # color names: https://matplotlib.org/mpl_examples/color/named_colors.png
    visualization_line_pairs = [(0, 1, 'b'), (1, 2, 'darkred'), (2, 3, 'r'), (3, 4, 'gold'),
                                (1, 5, 'darkgreen'), (5, 6, 'g'), (6, 7, 'lightgreen'),
                                (1, 8, 'darkcyan'), (8, 9, 'c'), (9, 10, 'skyblue'),
                                (1, 11, 'deeppink'), (11, 12, 'hotpink'), (12, 13, 'lightpink')]
    # class-level default; every __init__ path rebinds an instance attribute
    skeletons = []
    # upper-body joint pairs only (head + both arms)
    skeleton_line_pairs = [(0, 1, 'b'), (1, 2, 'darkred'), (2, 3, 'r'), (3, 4, 'gold'),
                          (1, 5, 'darkgreen'), (5, 6, 'g'), (6, 7, 'lightgreen')]

    def __init__(self, basepath, vid):
        # load skeleton data (and save it to pickle for next load)
        pickle_file = glob.glob(basepath + '/' + vid + '.pickle')

        if pickle_file:
            # SECURITY NOTE: unpickling executes arbitrary code — only load
            # pickle files this tool created itself.
            with open(pickle_file[0], 'rb') as file:
                self.skeletons = pickle.load(file)
        else:
            files = glob.glob(basepath + '/' + vid + '/*.json')
            if len(files) > 10:  # too few frames means a broken/partial dump
                files = sorted(files)
                self.skeletons = [self.read_skeleton_json(file) for file in files]
                with open(basepath + '/' + vid + '.pickle', 'wb') as file:
                    pickle.dump(self.skeletons, file)
            else:
                self.skeletons = []

    def read_skeleton_json(self, file):
        """Return the 'people' list of one OpenPose per-frame JSON file."""
        with open(file) as json_file:
            skeleton_json = json.load(json_file)
            return skeleton_json['people']

    def get(self, start_frame_no, end_frame_no, interval=1):
        """Return skeletons for frames [start_frame_no, end_frame_no),
        optionally subsampled every `interval` frames.  Returns [] when the
        slice contains no skeleton data at all."""
        chunk = self.skeletons[start_frame_no:end_frame_no]

        if is_list_empty(chunk):
            return []
        return chunk[::int(interval)] if interval > 1 else chunk


###############################################################################
# VIDEO

def read_video(base_path, vid):
    """Open the video file for `vid` under `base_path`.

    Returns a VideoWrapper, or None if no matching mp4 exists.
    Multiple matches indicate a corrupt dataset layout and abort.
    """
    files = glob.glob(base_path + '/*' + vid + '.mp4')
    if len(files) == 0:
        return None
    elif len(files) >= 2:
        assert False, 'more than one video file matched for vid ' + vid

    return VideoWrapper(files[0])


class VideoWrapper:
    """Thin wrapper around cv2.VideoCapture with frame/second conversion."""

    # class-level default; __init__ rebinds an instance attribute
    video = []

    def __init__(self, filepath):
        self.filepath = filepath
        self.video = cv2.VideoCapture(filepath)
        self.total_frames = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))
        self.height = self.video.get(cv2.CAP_PROP_FRAME_HEIGHT)
        self.framerate = self.video.get(cv2.CAP_PROP_FPS)

    def get_video_reader(self):
        return self.video

    def frame2second(self, frame_no):
        """Convert a frame index to a timestamp in seconds (float)."""
        return frame_no / self.framerate

    def second2frame(self, second):
        """Convert a timestamp in seconds to the nearest frame index."""
        return int(round(second * self.framerate))

    def set_current_frame(self, cur_frame_no):
        """Seek the underlying capture to the given frame index."""
        self.video.set(cv2.CAP_PROP_POS_FRAMES, cur_frame_no)


###############################################################################
# CLIP

def load_clip_data(vid):
    """Load clip segmentation data for `vid`; None if the file is missing."""
    try:
        with open("{}/{}.json".format(my_config.CLIP_PATH, vid)) as data_file:
            return json.load(data_file)
    except FileNotFoundError:
        return None


def load_clip_filtering_aux_info(vid):
    """Load clip-filtering auxiliary info for `vid`; None if missing."""
    try:
        with open("{}/{}_aux_info.json".format(my_config.CLIP_PATH, vid)) as data_file:
            return json.load(data_file)
    except FileNotFoundError:
        return None


#################################################################################
# SUBTITLE

class SubtitleWrapper:
    """Loads word-level subtitles for a video, from either YouTube
    auto-captions (WebVTT) or Gentle forced-alignment results.

    `self.subtitle` is a list of {'word', 'start', 'end'} dicts, or None
    when loading failed.
    """

    # matches '[hh:]mm:ss.mmm' (also accepts ',' as millisecond separator)
    TIMESTAMP_PATTERN = re.compile(r'(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})')

    def __init__(self, vid, mode):
        self.subtitle = []
        if mode == 'auto':
            self.load_auto_subtitle_data(vid)
        elif mode == 'gentle':
            self.load_gentle_subtitle(vid)

    def get(self):
        return self.subtitle

    # using gentle lib
    def load_gentle_subtitle(self, vid):
        """Load Gentle forced-aligner results, keeping successfully aligned
        words only.  Sets self.subtitle to None on failure."""
        try:
            with open("{}/{}_align_results.json".format(my_config.VIDEO_PATH, vid)) as data_file:
                data = json.load(data_file)
                if 'words' in data:
                    raw_subtitle = data['words']
                    for word in raw_subtitle:
                        if word['case'] == 'success':
                            self.subtitle.append(word)
                else:
                    self.subtitle = None
                return data
        except FileNotFoundError:
            self.subtitle = None

    # backward-compatible alias for the original (misspelled) method name
    laod_gentle_subtitle = load_gentle_subtitle

    # using youtube automatic subtitle
    def load_auto_subtitle_data(self, vid):
        """Parse a YouTube auto-caption VTT file into word-level timings.

        Word boundaries inside a caption come from inline '<c>'/timing tags;
        the first/last word fall back to the caption's own start/end times.
        Sets self.subtitle to None when the file is missing.
        """
        lang = my_config.LANG
        postfix_in_filename = '-' + lang + '-auto.vtt'
        file_list = glob.glob(my_config.SUBTITLE_PATH + '/*' + vid + postfix_in_filename)
        if len(file_list) > 1:
            print('more than one subtitle. check this.', file_list)
            self.subtitle = None
            assert False, 'ambiguous subtitle files for vid ' + vid

        if len(file_list) == 1:
            for subtitle_chunk in WebVTT().read(file_list[0]):
                raw_subtitle = str(subtitle_chunk.raw_text)
                # Always split: `str.split` on a missing separator just yields
                # the whole string.  The previous `if raw_subtitle.find('\n'):`
                # was wrong — find() returns -1 (truthy) on a miss and 0
                # (falsy) for a leading newline, so exactly the multi-line
                # case with a leading newline skipped the split and the loop
                # below iterated characters instead of lines.
                lines = raw_subtitle.split('\n')

                for raw_subtitle_chunk in lines:
                    # only lines carrying inline timestamps contain word timing
                    if self.TIMESTAMP_PATTERN.search(raw_subtitle_chunk) is None:
                        continue

                    # removes html tags and timing tags from caption text
                    raw_subtitle_chunk = raw_subtitle_chunk.replace("</c>", "")
                    raw_subtitle_chunk = re.sub(r"<c[.]\w+>", '', raw_subtitle_chunk)

                    word_list = []
                    raw_subtitle_s = subtitle_chunk.start_in_seconds
                    raw_subtitle_e = subtitle_chunk.end_in_seconds

                    word_chunk = raw_subtitle_chunk.split('<c>')
                    for word_idx, word in enumerate(word_chunk):
                        word_info = {}
                        if word_idx == len(word_chunk) - 1:
                            # last word has no trailing timing tag: its end is
                            # the caption end; its start is the previous
                            # word's end (or the caption start if it is the
                            # only word — the old code indexed word_list[-1]
                            # on an empty list here and crashed).
                            word_info['word'] = word
                            word_info['start'] = word_list[-1]['end'] if word_list else raw_subtitle_s
                            word_info['end'] = raw_subtitle_e
                            word_list.append(word_info)
                            break

                        # non-last words look like 'hello<00:00:01.000>'
                        parts = word.split("<")
                        word_info['word'] = parts[0]
                        word_info['end'] = self.get_seconds(parts[1][:-1])
                        if word_idx == 0:
                            word_info['start'] = raw_subtitle_s
                        else:
                            word_info['start'] = word_list[word_idx - 1]['end']
                        word_list.append(word_info)

                    self.subtitle.extend(word_list)
        else:
            print('subtitle file is not exist')
            self.subtitle = None

    # convert timestamp to second
    def get_seconds(self, word_time_e):
        """Convert a '[hh:]mm:ss.mmm' timestamp string to seconds (float).

        Raises ValueError on a malformed timestamp (the original printed an
        error and called exit(), which is inappropriate in library code).
        """
        time_value = self.TIMESTAMP_PATTERN.match(word_time_e)
        if not time_value:
            raise ValueError('wrong time stamp pattern: ' + repr(word_time_e))

        values = [int(v) if v else 0 for v in time_value.groups()]
        hours, minutes, seconds, milliseconds = values

        return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000