Initial commit

This commit is contained in:
prajwalkr 2020-08-17 22:38:19 +05:30
parent be07dd0c05
commit db7bc44a86
35 changed files with 4047 additions and 2 deletions

12
.gitignore vendored Normal file
View file

@ -0,0 +1,12 @@
*.pkl
*.jpg
*.mp4
*.pth
*.pyc
__pycache__
*.h5
*.avi
*.wav
filelists/*.txt
evaluation/test_filelists/lr*.txt
*.pyc

104
README.md
View file

@ -1,2 +1,102 @@
# Wav2Lip
This repository will be having the codes of our recent Lipsync paper which by far out classes anything else
# **Wav2Lip**: *Accurately Lip-syncing Videos In The Wild*
This code is part of the paper: _A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild_ published at ACM Multimedia 2020.
[[Paper]](#) | [[Project Page]](http://cvit.iiit.ac.in/research/projects/cvit-projects/a-lip-sync-expert-is-all-you-need-for-speech-to-lip-generation-in-the-wild/) | [[Demo Video]](#) | [[Interactive Demo]](#) | [[ReSyncED]](#)
----------
**Highlights**
----------
- Lip-sync videos to any target speech with high accuracy. Try our [interactive demo](#).
- Works for any identity, voice, and language. Also works for CGI faces and synthetic voices.
- Complete training code, inference code, and pretrained models are available.
- Or, quick-start with Google Colab: [Link](#)
- Several new, reliable evaluation benchmarks and metrics [[`evaluation/` folder of this repo]](https://github.com/Rudrabha/Wav2Lip/tree/master/evaluation) released.
- Code to calculate metrics reported in the paper is also made available.
Prerequisites
-------------
- `Python 3.5.2` (code has been tested with this version)
- ffmpeg: `sudo apt-get install ffmpeg`
- Install necessary packages using `pip install -r requirements.txt`
- Face detection [pre-trained model](https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth) should be downloaded to `face_detection/detection/sfd/s3fd.pth`
Getting the weights
----------
| Model | Description | Link to the model |
| :-------------: | :---------------: | :---------------: |
| Wav2Lip | Highly accurate lip-sync | [Link](#) |
| Wav2Lip + GAN | Slightly inferior lip-sync, but better visual quality | [Link](#) |
Lip-syncing videos using the pre-trained models (Inference)
-------
You can lip-sync any video to any audio:
```bash
python inference.py --checkpoint_path <ckpt> --face <video.mp4> --audio <an-audio-source>
```
The result is saved (by default) in `results/result_voice.mp4`. You can specify it as an argument, similar to several other available options. The audio source can be any file supported by `FFMPEG` containing audio data: `*.wav`, `*.mp3` or even a video file, from which the code will automatically extract the audio.
##### Tips for better results:
- Experiment with the `--pads` argument to adjust the detected face bounding box. Often leads to improved results. You might need to increase the bottom padding to include the chin region. E.g. `--pads 0 20 0 0`.
- Experiment with the `--resize_factor` argument, to get a lower resolution video. Why? The models are trained on faces which were at a lower resolution. You might get better, visually pleasing results for 720p videos than for 1080p videos (in many cases, the latter works well too).
- The Wav2Lip model without GAN usually needs more experimenting with the above two to get the most ideal results, and sometimes, can give you a better result as well.
Preparing LRS2 for training
----------
Our models are trained on LRS2. Training on other datasets might require small modifications to the code.
##### LRS2 dataset folder structure
```
data_root (mvlrs_v1)
├── main, pretrain (we use only main folder in this work)
| ├── list of folders
| │ ├── five-digit numbered video IDs ending with (.mp4)
```
Place the LRS2 filelists (train, val, test) `.txt` files in the `filelists/` folder.
##### Preprocess the dataset for fast training
```bash
python preprocess.py --data_root data_root/main --preprocessed_root lrs2_preprocessed/
```
Additional options like `batch_size` and number of GPUs to use in parallel to use can also be set.
##### Preprocessed LRS2 folder structure
```
preprocessed_root (lrs2_preprocessed)
├── list of folders
| ├── Folders with five-digit numbered video IDs
| │ ├── *.jpg
| │ ├── audio.wav
```
Train!
----------
There are two major steps: (i) Train the expert lip-sync discriminator, (ii) Train the Wav2Lip model(s).
##### Training the expert discriminator
You can download [the pre-trained weights]() if you want to skip this step. To train it:
```bash
python color_syncnet_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints>
```
##### Training the Wav2Lip models
You can either train the model without the additional visual quality disriminator (< 1 day of training) or use the discriminator (~2 days). For the former, run:
```bash
python wav2lip_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints> --syncnet_checkpoint_path <path_to_expert_disc_checkpoint>
```
To train with the visual quality discriminator, you should run `hq_wav2lip_train.py` instead. The arguments for both the files are similar. In both the cases, you can resume training as well. Look at `python wav2lip_train.py --help` for more details. You can also set additional less commonly-used hyper-parameters at the bottom of the `hparams.py` file.
Evaluation
----------
Will be updated.
License and Citation
----------
The software is licensed under the MIT License. Please cite the following paper if you have use this code: will be updated.
Acknowledgements
----------
Parts of the code structure is inspired by this [TTS repository](https://github.com/r9y9/deepvoice3_pytorch). We thank the author for this wonderful code. The code for Face Detection has been taken from the [face_alignment](https://github.com/1adrianb/face-alignment) repository. We thank the authors for releasing their code and models.

136
audio.py Normal file
View file

@ -0,0 +1,136 @@
import librosa
import librosa.filters
import numpy as np
import tensorflow as tf
from scipy import signal
from scipy.io import wavfile
from hparams import hparams as hp
def load_wav(path, sr):
return librosa.core.load(path, sr=sr)[0]
def save_wav(wav, path, sr):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
#proposed by @dsmiller
wavfile.write(path, sr, wav.astype(np.int16))
def save_wavenet_wav(wav, path, sr):
librosa.output.write_wav(path, wav, sr=sr)
def preemphasis(wav, k, preemphasize=True):
if preemphasize:
return signal.lfilter([1, -k], [1], wav)
return wav
def inv_preemphasis(wav, k, inv_preemphasize=True):
if inv_preemphasize:
return signal.lfilter([1], [1, -k], wav)
return wav
def get_hop_size():
hop_size = hp.hop_size
if hop_size is None:
assert hp.frame_shift_ms is not None
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
return hop_size
def linearspectrogram(wav):
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
if hp.signal_normalization:
return _normalize(S)
return S
def melspectrogram(wav):
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
if hp.signal_normalization:
return _normalize(S)
return S
def _lws_processor():
import lws
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
def _stft(y):
if hp.use_lws:
return _lws_processor(hp).stft(y).T
else:
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
##########################################################
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
"""Compute number of time frames of spectrogram
"""
pad = (fsize - fshift)
if length % fshift == 0:
M = (length + pad * 2 - fsize) // fshift + 1
else:
M = (length + pad * 2 - fsize) // fshift + 2
return M
def pad_lr(x, fsize, fshift):
"""Compute left and right padding
"""
M = num_frames(len(x), fsize, fshift)
pad = (fsize - fshift)
T = len(x) + 2 * pad
r = (M - 1) * fshift + fsize - T
return pad, pad + r
##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
# Conversions
_mel_basis = None
def _linear_to_mel(spectogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectogram)
def _build_mel_basis():
assert hp.fmax <= hp.sample_rate // 2
return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels,
fmin=hp.fmin, fmax=hp.fmax)
def _amp_to_db(x):
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, (x) * 0.05)
def _normalize(S):
if hp.allow_clipping_in_normalization:
if hp.symmetric_mels:
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
-hp.max_abs_value, hp.max_abs_value)
else:
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
if hp.symmetric_mels:
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
else:
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
def _denormalize(D):
if hp.allow_clipping_in_normalization:
if hp.symmetric_mels:
return (((np.clip(D, -hp.max_abs_value,
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+ hp.min_level_db)
else:
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
if hp.symmetric_mels:
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
else:
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)

1
checkpoints/README.md Normal file
View file

@ -0,0 +1 @@
Place all your checkpoints (.pth files) here.

281
color_syncnet_train.py Normal file
View file

@ -0,0 +1,281 @@
from os.path import dirname, join, basename, isfile
from tqdm import tqdm
from models import SyncNet_color as SyncNet
import audio
import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from glob import glob
import os, random, cv2, argparse
from hparams import hparams, get_image_list
parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str)
args = parser.parse_args()
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))
syncnet_T = 5
syncnet_mel_step_size = 16
class Dataset(object):
def __init__(self, split):
self.all_videos = get_image_list(args.data_root, split)
def get_frame_id(self, frame):
return int(basename(frame).split('.')[0])
def get_window(self, start_frame):
start_id = self.get_frame_id(start_frame)
vidname = dirname(start_frame)
window_fnames = []
for frame_id in range(start_id, start_id + syncnet_T):
frame = join(vidname, '{}.jpg'.format(frame_id))
if not isfile(frame):
return None
window_fnames.append(frame)
return window_fnames
def crop_audio_window(self, spec, start_frame):
# num_frames = (T x hop_size * fps) / sample_rate
start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
end_idx = start_idx + syncnet_mel_step_size
return spec[start_idx : end_idx, :]
def __len__(self):
return len(self.all_videos)
def __getitem__(self, idx):
while 1:
idx = random.randint(0, len(self.all_videos) - 1)
vidname = self.all_videos[idx]
img_names = list(glob(join(vidname, '*.jpg')))
if len(img_names) <= 3 * syncnet_T:
continue
img_name = random.choice(img_names)
wrong_img_name = random.choice(img_names)
while wrong_img_name == img_name:
wrong_img_name = random.choice(img_names)
if random.choice([True, False]):
y = torch.ones(1).float()
chosen = img_name
else:
y = torch.zeros(1).float()
chosen = wrong_img_name
window_fnames = self.get_window(chosen)
if window_fnames is None:
continue
window = []
all_read = True
for fname in window_fnames:
img = cv2.imread(fname)
if img is None:
all_read = False
break
try:
img = cv2.resize(img, (hparams.img_size, hparams.img_size))
except Exception as e:
all_read = False
break
window.append(img)
if not all_read: continue
try:
wavpath = join(vidname, "audio.wav")
wav = audio.load_wav(wavpath, hparams.sample_rate)
orig_mel = audio.melspectrogram(wav).T
except Exception as e:
continue
mel = self.crop_audio_window(orig_mel.copy(), img_name)
if (mel.shape[0] != syncnet_mel_step_size):
continue
# H x W x 3 * T
x = np.concatenate(window, axis=2) / 255.
x = x.transpose(2, 0, 1)
x = x[:, x.shape[1]//2:]
x = torch.FloatTensor(x)
mel = torch.FloatTensor(mel.T).unsqueeze(0)
return x, mel, y
logloss = nn.BCELoss()
def cosine_loss(a, v, y):
d = nn.functional.cosine_similarity(a, v)
loss = logloss(d.unsqueeze(1), y)
return loss
def train(device, model_single, train_data_loader, test_data_loader, optimizer,
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
model = nn.DataParallel(model_single)
global global_step, global_epoch
resumed_step = global_step
while global_epoch < nepochs:
running_loss = 0.
prog_bar = tqdm(enumerate(train_data_loader))
for step, (x, mel, y) in prog_bar:
model.train()
optimizer.zero_grad()
# Transform data to CUDA device
x = x.to(device)
mel = mel.to(device)
a, v = model(mel, x)
y = y.to(device)
loss = cosine_loss(a, v, y)
loss.backward()
optimizer.step()
global_step += 1
cur_session_steps = global_step - resumed_step
running_loss += loss.item()
if global_step == 1 or global_step % checkpoint_interval == 0:
save_checkpoint(
model_single, optimizer, global_step, checkpoint_dir, global_epoch)
if global_step % hparams.syncnet_eval_interval == 0:
with torch.no_grad():
eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1)))
global_epoch += 1
def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
eval_steps = 1400
print('Evaluating for {} steps'.format(eval_steps))
losses = []
while 1:
for step, (x, mel, y) in enumerate(test_data_loader):
model.eval()
# Transform data to CUDA device
x = x.to(device)
mel = mel.to(device)
a, v = model(mel, x)
y = y.to(device)
loss = cosine_loss(a, v, y)
losses.append(loss.item())
if step > eval_steps: break
averaged_loss = sum(losses) / len(losses)
print(averaged_loss)
return
def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
checkpoint_path = join(
checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
torch.save({
"state_dict": model.state_dict(),
"optimizer": optimizer_state,
"global_step": step,
"global_epoch": epoch,
}, checkpoint_path)
print("Saved checkpoint:", checkpoint_path)
def _load(checkpoint_path):
if use_cuda:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
def load_checkpoint(path, model, optimizer, reset_optimizer=False):
global global_step
global global_epoch
print("Load checkpoint from: {}".format(path))
checkpoint = _load(path)
model.load_state_dict(checkpoint["state_dict"])
if not reset_optimizer:
optimizer_state = checkpoint["optimizer"]
if optimizer_state is not None:
print("Load optimizer state from {}".format(path))
optimizer.load_state_dict(checkpoint["optimizer"])
global_step = checkpoint["global_step"]
global_epoch = checkpoint["global_epoch"]
return model
if __name__ == "__main__":
checkpoint_dir = args.checkpoint_dir
checkpoint_path = args.checkpoint_path
if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)
# Dataset and Dataloader setup
train_dataset = Dataset('train')
test_dataset = Dataset('val')
train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
num_workers=hparams.num_workers)
test_data_loader = data_utils.DataLoader(
test_dataset, batch_size=hparams.syncnet_batch_size,
num_workers=8)
device = torch.device("cuda" if use_cuda else "cpu")
# Model
model = SyncNet().to(device)
print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
lr=hparams.syncnet_lr)
if checkpoint_path is not None:
load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)
train(device, model, train_data_loader, test_data_loader, optimizer,
checkpoint_dir=checkpoint_dir,
checkpoint_interval=hparams.syncnet_checkpoint_interval,
nepochs=hparams.nepochs)

View file

@ -0,0 +1,238 @@
from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse
import dlib, json, subprocess
from tqdm import tqdm
from glob import glob
import torch
sys.path.append('../')
import audio
import face_detection
from models import Wav2Lip
parser = argparse.ArgumentParser(description='Code to generate results for test filelists')
parser.add_argument('--filelist', type=str,
help='Filepath of filelist file to read', required=True)
parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
required=True)
parser.add_argument('--data_root', type=str, required=True)
parser.add_argument('--checkpoint_path', type=str,
help='Name of saved checkpoint to load weights from', required=True)
parser.add_argument('--pads', nargs='+', type=int, default=[0, 0, 0, 0],
help='Padding (top, bottom, left, right)')
parser.add_argument('--face_det_batch_size', type=int,
help='Single GPU batch size for face detection', default=64)
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
# parser.add_argument('--resize_factor', default=1, type=int)
args = parser.parse_args()
args.img_size = 96
def get_smoothened_boxes(boxes, T):
for i in range(len(boxes)):
if i + T > len(boxes):
window = boxes[len(boxes) - T:]
else:
window = boxes[i : i + T]
boxes[i] = np.mean(window, axis=0)
return boxes
def face_detect(images):
batch_size = args.face_det_batch_size
while 1:
predictions = []
try:
for i in range(0, len(images), batch_size):
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
except RuntimeError:
if batch_size == 1:
raise RuntimeError('Image too big to run face detection on GPU')
batch_size //= 2
args.face_det_batch_size = batch_size
print('Recovering from OOM error; New batch size: {}'.format(batch_size))
continue
break
results = []
pady1, pady2, padx1, padx2 = args.pads
for rect, image in zip(predictions, images):
if rect is None:
raise ValueError('Face not detected!')
y1 = max(0, rect[1] - pady1)
y2 = min(image.shape[0], rect[3] + pady2)
x1 = max(0, rect[0] - padx1)
x2 = min(image.shape[1], rect[2] + padx2)
results.append([x1, y1, x2, y2])
boxes = get_smoothened_boxes(np.array(results), T=5)
results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
return results
def datagen(frames, face_det_results, mels):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
for i, m in enumerate(mels):
if i >= len(frames): raise ValueError('Equal or less lengths only')
frame_to_save = frames[i].copy()
face, coords, valid_frame = face_det_results[i].copy()
if not valid_frame:
continue
face = cv2.resize(face, (args.img_size, args.img_size))
img_batch.append(face)
mel_batch.append(m)
frame_batch.append(frame_to_save)
coords_batch.append(coords)
if len(img_batch) >= args.wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, args.img_size//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, args.img_size//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch
fps = 25
mel_step_size = 16
mel_idx_multiplier = 80./fps
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
flip_input=False, device=device)
def _load(checkpoint_path):
if device == 'cuda':
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
def load_model(path):
model = Wav2Lip()
print("Load checkpoint from: {}".format(path))
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
for k, v in s.items():
new_s[k.replace('module.', '')] = v
model.load_state_dict(new_s)
model = model.to(device)
return model.eval()
model = load_model(args.checkpoint_path)
def main():
assert args.data_root is not None
data_root = args.data_root
if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)
with open(args.filelist, 'r') as filelist:
lines = filelist.readlines()
for idx, line in enumerate(tqdm(lines)):
audio_src, video = line.strip().split()
audio_src = os.path.join(data_root, audio_src) + '.mp4'
video = os.path.join(data_root, video) + '.mp4'
command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
subprocess.call(command, shell=True)
temp_audio = '../temp/temp.wav'
wav = audio.load_wav(temp_audio, 16000)
mel = audio.melspectrogram(wav)
if np.isnan(mel.reshape(-1)).sum() > 0:
continue
mel_chunks = []
i = 0
while 1:
start_idx = int(i * mel_idx_multiplier)
if start_idx + mel_step_size > len(mel[0]):
break
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
i += 1
video_stream = cv2.VideoCapture(video)
full_frames = []
while 1:
still_reading, frame = video_stream.read()
if not still_reading or len(full_frames) > len(mel_chunks):
video_stream.release()
break
full_frames.append(frame)
if len(full_frames) < len(mel_chunks):
continue
full_frames = full_frames[:len(mel_chunks)]
try:
face_det_results = face_detect(full_frames.copy())
except ValueError as e:
continue
batch_size = args.wav2lip_batch_size
gen = datagen(full_frames.copy(), face_det_results, mel_chunks)
for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
if i == 0:
frame_h, frame_w = full_frames[0].shape[:-1]
out = cv2.VideoWriter('../temp/result.avi',
cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
for pl, f, c in zip(pred, frames, coords):
y1, y2, x1, x2 = c
pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
f[y1:y2, x1:x2] = pl
out.write(f)
out.release()
vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format(temp_audio,
'../temp/result.avi', vid)
subprocess.call(command, shell=True)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,305 @@
from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse
import dlib, json, subprocess
from tqdm import tqdm
from glob import glob
import torch
sys.path.append('../')
import audio
import face_detection
from models import Wav2Lip
parser = argparse.ArgumentParser(description='Code to generate results on ReSyncED evaluation set')
parser.add_argument('--mode', type=str,
help='random | dubbed | tts', required=True)
parser.add_argument('--filelist', type=str,
help='Filepath of filelist file to read', default=None)
parser.add_argument('--results_dir', type=str, help='Folder to save all results into',
required=True)
parser.add_argument('--data_root', type=str, required=True)
parser.add_argument('--checkpoint_path', type=str,
help='Name of saved checkpoint to load weights from', required=True)
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
help='Padding (top, bottom, left, right)')
parser.add_argument('--face_det_batch_size', type=int,
help='Single GPU batch size for face detection', default=16)
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128)
parser.add_argument('--face_res', help='Approximate resolution of the face at which to test', default=180)
parser.add_argument('--min_frame_res', help='Do not downsample further below this frame resolution', default=480)
parser.add_argument('--max_frame_res', help='Downsample to at least this frame resolution', default=720)
# parser.add_argument('--resize_factor', default=1, type=int)
args = parser.parse_args()
args.img_size = 96
def get_smoothened_boxes(boxes, T):
for i in range(len(boxes)):
if i + T > len(boxes):
window = boxes[len(boxes) - T:]
else:
window = boxes[i : i + T]
boxes[i] = np.mean(window, axis=0)
return boxes
def rescale_frames(images):
rect = detector.get_detections_for_batch(np.array([images[0]]))[0]
if rect is None:
raise ValueError('Face not detected!')
h, w = images[0].shape[:-1]
x1, y1, x2, y2 = rect
face_size = max(np.abs(y1 - y2), np.abs(x1 - x2))
diff = np.abs(face_size - args.face_res)
for factor in range(2, 16):
downsampled_res = face_size // factor
if min(h//factor, w//factor) < args.min_frame_res: break
if np.abs(downsampled_res - args.face_res) >= diff: break
factor -= 1
if factor == 1: return images
return [cv2.resize(im, (im.shape[1]//(factor), im.shape[0]//(factor))) for im in images]
def face_detect(images):
batch_size = args.face_det_batch_size
images = rescale_frames(images)
while 1:
predictions = []
try:
for i in range(0, len(images), batch_size):
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
except RuntimeError:
if batch_size == 1:
raise RuntimeError('Image too big to run face detection on GPU')
batch_size //= 2
print('Recovering from OOM error; New batch size: {}'.format(batch_size))
continue
break
results = []
pady1, pady2, padx1, padx2 = args.pads
for rect, image in zip(predictions, images):
if rect is None:
raise ValueError('Face not detected!')
y1 = max(0, rect[1] - pady1)
y2 = min(image.shape[0], rect[3] + pady2)
x1 = max(0, rect[0] - padx1)
x2 = min(image.shape[1], rect[2] + padx2)
results.append([x1, y1, x2, y2])
boxes = get_smoothened_boxes(np.array(results), T=5)
results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
return results, images
def datagen(frames, face_det_results, mels):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
for i, m in enumerate(mels):
if i >= len(frames): raise ValueError('Equal or less lengths only')
frame_to_save = frames[i].copy()
face, coords, valid_frame = face_det_results[i].copy()
if not valid_frame:
continue
face = cv2.resize(face, (args.img_size, args.img_size))
img_batch.append(face)
mel_batch.append(m)
frame_batch.append(frame_to_save)
coords_batch.append(coords)
if len(img_batch) >= args.wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, args.img_size//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, args.img_size//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch
def increase_frames(frames, l):
## evenly duplicating frames to increase length of video
while len(frames) < l:
dup_every = float(l) / len(frames)
final_frames = []
next_duplicate = 0.
for i, f in enumerate(frames):
final_frames.append(f)
if int(np.ceil(next_duplicate)) == i:
final_frames.append(f)
next_duplicate += dup_every
frames = final_frames
return frames[:l]
mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
flip_input=False, device=device)
def _load(checkpoint_path):
if device == 'cuda':
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
def load_model(path):
model = Wav2Lip()
print("Load checkpoint from: {}".format(path))
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
for k, v in s.items():
new_s[k.replace('module.', '')] = v
model.load_state_dict(new_s)
model = model.to(device)
return model.eval()
model = load_model(args.checkpoint_path)
def main():
if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir)
if args.mode == 'dubbed':
files = listdir(args.data_root)
lines = ['{} {}'.format(f, f) for f in files]
else:
assert args.filelist is not None
with open(args.filelist, 'r') as filelist:
lines = filelist.readlines()
for idx, line in enumerate(tqdm(lines)):
video, audio_src = line.strip().split()
audio_src = os.path.join(args.data_root, audio_src)
video = os.path.join(args.data_root, video)
command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav')
subprocess.call(command, shell=True)
temp_audio = '../temp/temp.wav'
wav = audio.load_wav(temp_audio, 16000)
mel = audio.melspectrogram(wav)
if np.isnan(mel.reshape(-1)).sum() > 0:
raise ValueError('Mel contains nan!')
video_stream = cv2.VideoCapture(video)
fps = video_stream.get(cv2.CAP_PROP_FPS)
mel_idx_multiplier = 80./fps
full_frames = []
while 1:
still_reading, frame = video_stream.read()
if not still_reading:
video_stream.release()
break
if min(frame.shape[:-1]) > args.max_frame_res:
h, w = frame.shape[:-1]
scale_factor = min(h, w) / float(args.max_frame_res)
h = int(h/scale_factor)
w = int(w/scale_factor)
frame = cv2.resize(frame, (w, h))
full_frames.append(frame)
mel_chunks = []
i = 0
while 1:
start_idx = int(i * mel_idx_multiplier)
if start_idx + mel_step_size > len(mel[0]):
break
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
i += 1
if len(full_frames) < len(mel_chunks):
if args.mode == 'tts':
full_frames = increase_frames(full_frames, len(mel_chunks))
else:
raise ValueError('#Frames, audio length mismatch')
else:
full_frames = full_frames[:len(mel_chunks)]
try:
face_det_results, full_frames = face_detect(full_frames.copy())
except ValueError as e:
continue
batch_size = args.wav2lip_batch_size
gen = datagen(full_frames.copy(), face_det_results, mel_chunks)
for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
if i == 0:
frame_h, frame_w = full_frames[0].shape[:-1]
out = cv2.VideoWriter('../temp/result.avi',
cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
for pl, f, c in zip(pred, frames, coords):
y1, y2, x1, x2 = c
pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
f[y1:y2, x1:x2] = pl
out.write(f)
out.release()
vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format('../temp/temp.wav',
'../temp/result.avi', vid)
subprocess.call(command, shell=True)
if __name__ == '__main__':
main()

View file

@ -0,0 +1 @@
Place LRS2, LRW, and LRS3 (and any other) test set filelists here.

View file

@ -0,0 +1,160 @@
sachin.mp4 emma_cropped.mp4
sachin.mp4 mourinho.mp4
sachin.mp4 elon.mp4
sachin.mp4 messi2.mp4
sachin.mp4 cr1.mp4
sachin.mp4 sachin.mp4
sachin.mp4 sg.mp4
sachin.mp4 fergi.mp4
sachin.mp4 spanish_lec1.mp4
sachin.mp4 bush_small.mp4
sachin.mp4 macca_cut.mp4
sachin.mp4 ca_cropped.mp4
sachin.mp4 lecun.mp4
sachin.mp4 spanish_lec0.mp4
srk.mp4 emma_cropped.mp4
srk.mp4 mourinho.mp4
srk.mp4 elon.mp4
srk.mp4 messi2.mp4
srk.mp4 cr1.mp4
srk.mp4 srk.mp4
srk.mp4 sachin.mp4
srk.mp4 sg.mp4
srk.mp4 fergi.mp4
srk.mp4 spanish_lec1.mp4
srk.mp4 bush_small.mp4
srk.mp4 macca_cut.mp4
srk.mp4 ca_cropped.mp4
srk.mp4 guardiola.mp4
srk.mp4 lecun.mp4
srk.mp4 spanish_lec0.mp4
cr1.mp4 emma_cropped.mp4
cr1.mp4 elon.mp4
cr1.mp4 messi2.mp4
cr1.mp4 cr1.mp4
cr1.mp4 spanish_lec1.mp4
cr1.mp4 bush_small.mp4
cr1.mp4 macca_cut.mp4
cr1.mp4 ca_cropped.mp4
cr1.mp4 lecun.mp4
cr1.mp4 spanish_lec0.mp4
macca_cut.mp4 emma_cropped.mp4
macca_cut.mp4 elon.mp4
macca_cut.mp4 messi2.mp4
macca_cut.mp4 spanish_lec1.mp4
macca_cut.mp4 macca_cut.mp4
macca_cut.mp4 ca_cropped.mp4
macca_cut.mp4 spanish_lec0.mp4
lecun.mp4 emma_cropped.mp4
lecun.mp4 elon.mp4
lecun.mp4 messi2.mp4
lecun.mp4 spanish_lec1.mp4
lecun.mp4 macca_cut.mp4
lecun.mp4 ca_cropped.mp4
lecun.mp4 lecun.mp4
lecun.mp4 spanish_lec0.mp4
messi2.mp4 emma_cropped.mp4
messi2.mp4 elon.mp4
messi2.mp4 messi2.mp4
messi2.mp4 spanish_lec1.mp4
messi2.mp4 macca_cut.mp4
messi2.mp4 ca_cropped.mp4
messi2.mp4 spanish_lec0.mp4
ca_cropped.mp4 emma_cropped.mp4
ca_cropped.mp4 elon.mp4
ca_cropped.mp4 spanish_lec1.mp4
ca_cropped.mp4 ca_cropped.mp4
ca_cropped.mp4 spanish_lec0.mp4
spanish_lec1.mp4 spanish_lec1.mp4
spanish_lec1.mp4 spanish_lec0.mp4
elon.mp4 elon.mp4
elon.mp4 spanish_lec1.mp4
elon.mp4 spanish_lec0.mp4
guardiola.mp4 emma_cropped.mp4
guardiola.mp4 mourinho.mp4
guardiola.mp4 elon.mp4
guardiola.mp4 messi2.mp4
guardiola.mp4 cr1.mp4
guardiola.mp4 sachin.mp4
guardiola.mp4 sg.mp4
guardiola.mp4 fergi.mp4
guardiola.mp4 spanish_lec1.mp4
guardiola.mp4 bush_small.mp4
guardiola.mp4 macca_cut.mp4
guardiola.mp4 ca_cropped.mp4
guardiola.mp4 guardiola.mp4
guardiola.mp4 lecun.mp4
guardiola.mp4 spanish_lec0.mp4
fergi.mp4 emma_cropped.mp4
fergi.mp4 mourinho.mp4
fergi.mp4 elon.mp4
fergi.mp4 messi2.mp4
fergi.mp4 cr1.mp4
fergi.mp4 sachin.mp4
fergi.mp4 sg.mp4
fergi.mp4 fergi.mp4
fergi.mp4 spanish_lec1.mp4
fergi.mp4 bush_small.mp4
fergi.mp4 macca_cut.mp4
fergi.mp4 ca_cropped.mp4
fergi.mp4 lecun.mp4
fergi.mp4 spanish_lec0.mp4
spanish.mp4 emma_cropped.mp4
spanish.mp4 spanish.mp4
spanish.mp4 mourinho.mp4
spanish.mp4 elon.mp4
spanish.mp4 messi2.mp4
spanish.mp4 cr1.mp4
spanish.mp4 srk.mp4
spanish.mp4 sachin.mp4
spanish.mp4 sg.mp4
spanish.mp4 fergi.mp4
spanish.mp4 spanish_lec1.mp4
spanish.mp4 bush_small.mp4
spanish.mp4 macca_cut.mp4
spanish.mp4 ca_cropped.mp4
spanish.mp4 guardiola.mp4
spanish.mp4 lecun.mp4
spanish.mp4 spanish_lec0.mp4
bush_small.mp4 emma_cropped.mp4
bush_small.mp4 elon.mp4
bush_small.mp4 messi2.mp4
bush_small.mp4 spanish_lec1.mp4
bush_small.mp4 bush_small.mp4
bush_small.mp4 macca_cut.mp4
bush_small.mp4 ca_cropped.mp4
bush_small.mp4 lecun.mp4
bush_small.mp4 spanish_lec0.mp4
emma_cropped.mp4 emma_cropped.mp4
emma_cropped.mp4 elon.mp4
emma_cropped.mp4 spanish_lec1.mp4
emma_cropped.mp4 spanish_lec0.mp4
sg.mp4 emma_cropped.mp4
sg.mp4 mourinho.mp4
sg.mp4 elon.mp4
sg.mp4 messi2.mp4
sg.mp4 cr1.mp4
sg.mp4 sachin.mp4
sg.mp4 sg.mp4
sg.mp4 fergi.mp4
sg.mp4 spanish_lec1.mp4
sg.mp4 bush_small.mp4
sg.mp4 macca_cut.mp4
sg.mp4 ca_cropped.mp4
sg.mp4 lecun.mp4
sg.mp4 spanish_lec0.mp4
spanish_lec0.mp4 spanish_lec0.mp4
mourinho.mp4 emma_cropped.mp4
mourinho.mp4 mourinho.mp4
mourinho.mp4 elon.mp4
mourinho.mp4 messi2.mp4
mourinho.mp4 cr1.mp4
mourinho.mp4 sachin.mp4
mourinho.mp4 sg.mp4
mourinho.mp4 fergi.mp4
mourinho.mp4 spanish_lec1.mp4
mourinho.mp4 bush_small.mp4
mourinho.mp4 macca_cut.mp4
mourinho.mp4 ca_cropped.mp4
mourinho.mp4 lecun.mp4
mourinho.mp4 spanish_lec0.mp4

View file

@ -0,0 +1,18 @@
adam_1.mp4 andreng_optimization.wav
agad_2.mp4 agad_2.wav
agad_1.mp4 agad_1.wav
agad_3.mp4 agad_3.wav
rms_prop_1.mp4 rms_prop_tts.wav
tf_1.mp4 tf_1.wav
tf_2.mp4 tf_2.wav
andrew_ng_ai_business.mp4 andrewng_business_tts.wav
covid_autopsy_1.mp4 autopsy_tts.wav
news_1.mp4 news_tts.wav
andrew_ng_fund_1.mp4 andrewng_ai_fund.wav
covid_treatments_1.mp4 covid_tts.wav
pytorch_v_tf.mp4 pytorch_vs_tf_eng.wav
pytorch_1.mp4 pytorch.wav
pkb_1.mp4 pkb_1.wav
ss_1.mp4 ss_1.wav
carlsen_1.mp4 carlsen_eng.wav
french.mp4 french.wav

1
face_detection/README.md Normal file
View file

@ -0,0 +1 @@
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.

View file

@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
__author__ = """Adrian Bulat"""
__email__ = 'adrian.bulat@nottingham.ac.uk'
__version__ = '1.0.1'
from .api import FaceAlignment, LandmarksType, NetworkSize

79
face_detection/api.py Normal file
View file

@ -0,0 +1,79 @@
from __future__ import print_function
import os
import torch
from torch.utils.model_zoo import load_url
from enum import Enum
import numpy as np
import cv2
try:
import urllib.request as request_file
except BaseException:
import urllib as request_file
from .models import FAN, ResNetDepth
from .utils import *
class LandmarksType(Enum):
"""Enum class defining the type of landmarks to detect.
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
``_2halfD`` - this points represent the projection of the 3D points into 3D
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
"""
_2D = 1
_2halfD = 2
_3D = 3
class NetworkSize(Enum):
# TINY = 1
# SMALL = 2
# MEDIUM = 3
LARGE = 4
def __new__(cls, value):
member = object.__new__(cls)
member._value_ = value
return member
def __int__(self):
return self.value
ROOT = os.path.dirname(os.path.abspath(__file__))
class FaceAlignment:
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
self.device = device
self.flip_input = flip_input
self.landmarks_type = landmarks_type
self.verbose = verbose
network_size = int(network_size)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
# Get the face detector
face_detector_module = __import__('face_detection.detection.' + face_detector,
globals(), locals(), [face_detector], 0)
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
def get_detections_for_batch(self, images):
images = images[..., ::-1]
detected_faces = self.face_detector.detect_from_batch(images.copy())
results = []
for i, d in enumerate(detected_faces):
if len(d) == 0:
results.append(None)
continue
d = d[0]
d = np.clip(d, 0, None)
x1, y1, x2, y2 = map(int, d[:-1])
results.append((x1, y1, x2, y2))
return results

View file

@ -0,0 +1 @@
from .core import FaceDetector

View file

@ -0,0 +1,130 @@
import logging
import glob
from tqdm import tqdm
import numpy as np
import torch
import cv2
class FaceDetector(object):
"""An abstract class representing a face detector.
Any other face detection implementation must subclass it. All subclasses
must implement ``detect_from_image``, that return a list of detected
bounding boxes. Optionally, for speed considerations detect from path is
recommended.
"""
def __init__(self, device, verbose):
self.device = device
self.verbose = verbose
if verbose:
if 'cpu' in device:
logger = logging.getLogger(__name__)
logger.warning("Detection running on CPU, this may be potentially slow.")
if 'cpu' not in device and 'cuda' not in device:
if verbose:
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
raise ValueError
def detect_from_image(self, tensor_or_path):
"""Detects faces in a given image.
This function detects the faces present in a provided BGR(usually)
image. The input can be either the image itself or the path to it.
Arguments:
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
to an image or the image itself.
Example::
>>> path_to_image = 'data/image_01.jpg'
... detected_faces = detect_from_image(path_to_image)
[A list of bounding boxes (x1, y1, x2, y2)]
>>> image = cv2.imread(path_to_image)
... detected_faces = detect_from_image(image)
[A list of bounding boxes (x1, y1, x2, y2)]
"""
raise NotImplementedError
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
"""Detects faces from all the images present in a given directory.
Arguments:
path {string} -- a string containing a path that points to the folder containing the images
Keyword Arguments:
extensions {list} -- list of string containing the extensions to be
consider in the following format: ``.extension_name`` (default:
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
folder recursively (default: {False}) show_progress_bar {bool} --
display a progressbar (default: {True})
Example:
>>> directory = 'data'
... detected_faces = detect_from_directory(directory)
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
"""
if self.verbose:
logger = logging.getLogger(__name__)
if len(extensions) == 0:
if self.verbose:
logger.error("Expected at list one extension, but none was received.")
raise ValueError
if self.verbose:
logger.info("Constructing the list of images.")
additional_pattern = '/**/*' if recursive else '/*'
files = []
for extension in extensions:
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
if self.verbose:
logger.info("Finished searching for images. %s images found", len(files))
logger.info("Preparing to run the detection.")
predictions = {}
for image_path in tqdm(files, disable=not show_progress_bar):
if self.verbose:
logger.info("Running the face detector on image: %s", image_path)
predictions[image_path] = self.detect_from_image(image_path)
if self.verbose:
logger.info("The detector was successfully run on all %s images", len(files))
return predictions
@property
def reference_scale(self):
raise NotImplementedError
@property
def reference_x_shift(self):
raise NotImplementedError
@property
def reference_y_shift(self):
raise NotImplementedError
@staticmethod
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
Arguments:
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
"""
if isinstance(tensor_or_path, str):
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
elif torch.is_tensor(tensor_or_path):
# Call cpu in case its coming from cuda
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
elif isinstance(tensor_or_path, np.ndarray):
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
else:
raise TypeError

View file

@ -0,0 +1 @@
from .sfd_detector import SFDDetector as FaceDetector

View file

@ -0,0 +1,129 @@
from __future__ import print_function
import os
import sys
import cv2
import random
import datetime
import time
import math
import argparse
import numpy as np
import torch
try:
from iou import IOU
except BaseException:
# IOU cython speedup 10x
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
sa = abs((ax2 - ax1) * (ay2 - ay1))
sb = abs((bx2 - bx1) * (by2 - by1))
x1, y1 = max(ax1, bx1), max(ay1, by1)
x2, y2 = min(ax2, bx2), min(ay2, by2)
w = x2 - x1
h = y2 - y1
if w < 0 or h < 0:
return 0.0
else:
return 1.0 * w * h / (sa + sb - w * h)
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
dw, dh = math.log(ww / aww), math.log(hh / ahh)
return dx, dy, dw, dh
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
xc, yc = dx * aww + axc, dy * ahh + ayc
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
return x1, y1, x2, y2
def nms(dets, thresh):
if 0 == len(dets):
return []
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def encode(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Args:
matched: (tensor) Coords of ground truth for each prior in point-form
Shape: [num_priors, 4].
priors: (tensor) Prior boxes in center-offset form
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
encoded boxes (tensor), Shape: [num_priors, 4]
"""
# dist b/t match center and prior's center
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
# encode variance
g_cxcy /= (variances[0] * priors[:, 2:])
# match wh / prior wh
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
g_wh = torch.log(g_wh) / variances[1]
# return target for smooth_l1_loss
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
def decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
def batch_decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
boxes[:, :, 2:] += boxes[:, :, :2]
return boxes

View file

@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F
import os
import sys
import cv2
import random
import datetime
import math
import argparse
import numpy as np
import scipy.io as sio
import zipfile
from .net_s3fd import s3fd
from .bbox import *
def detect(net, img, device):
img = img - np.array([104, 117, 123])
img = img.transpose(2, 0, 1)
img = img.reshape((1,) + img.shape)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
img = torch.from_numpy(img).float().to(device)
BB, CC, HH, WW = img.size()
with torch.no_grad():
olist = net(img)
bboxlist = []
for i in range(len(olist) // 2):
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
olist = [oelem.data.cpu() for oelem in olist]
for i in range(len(olist) // 2):
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
FB, FC, FH, FW = ocls.size() # feature map size
stride = 2**(i + 2) # 4,8,16,32,64,128
anchor = stride * 4
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
for Iindex, hindex, windex in poss:
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
score = ocls[0, 1, hindex, windex]
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
variances = [0.1, 0.2]
box = decode(loc, priors, variances)
x1, y1, x2, y2 = box[0] * 1.0
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
bboxlist.append([x1, y1, x2, y2, score])
bboxlist = np.array(bboxlist)
if 0 == len(bboxlist):
bboxlist = np.zeros((1, 5))
return bboxlist
def batch_detect(net, imgs, device):
imgs = imgs - np.array([104, 117, 123])
imgs = imgs.transpose(0, 3, 1, 2)
if 'cuda' in device:
torch.backends.cudnn.benchmark = True
imgs = torch.from_numpy(imgs).float().to(device)
BB, CC, HH, WW = imgs.size()
with torch.no_grad():
olist = net(imgs)
bboxlist = []
for i in range(len(olist) // 2):
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
olist = [oelem.data.cpu() for oelem in olist]
for i in range(len(olist) // 2):
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
FB, FC, FH, FW = ocls.size() # feature map size
stride = 2**(i + 2) # 4,8,16,32,64,128
anchor = stride * 4
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
for Iindex, hindex, windex in poss:
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
score = ocls[:, 1, hindex, windex]
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
variances = [0.1, 0.2]
box = batch_decode(loc, priors, variances)
box = box[:, 0] * 1.0
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
bboxlist = np.array(bboxlist)
if 0 == len(bboxlist):
bboxlist = np.zeros((1, BB, 5))
return bboxlist
def flip_detect(net, img, device):
img = cv2.flip(img, 1)
b = detect(net, img, device)
bboxlist = np.zeros(b.shape)
bboxlist[:, 0] = img.shape[1] - b[:, 2]
bboxlist[:, 1] = b[:, 1]
bboxlist[:, 2] = img.shape[1] - b[:, 0]
bboxlist[:, 3] = b[:, 3]
bboxlist[:, 4] = b[:, 4]
return bboxlist
def pts_to_bb(pts):
min_x, min_y = np.min(pts, axis=0)
max_x, max_y = np.max(pts, axis=0)
return np.array([min_x, min_y, max_x, max_y])

View file

@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class L2Norm(nn.Module):
def __init__(self, n_channels, scale=1.0):
super(L2Norm, self).__init__()
self.n_channels = n_channels
self.scale = scale
self.eps = 1e-10
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
self.weight.data *= 0.0
self.weight.data += self.scale
def forward(self, x):
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
x = x / norm * self.weight.view(1, -1, 1, 1)
return x
class s3fd(nn.Module):
def __init__(self):
super(s3fd, self).__init__()
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
self.conv3_3_norm = L2Norm(256, scale=10)
self.conv4_3_norm = L2Norm(512, scale=8)
self.conv5_3_norm = L2Norm(512, scale=5)
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
def forward(self, x):
h = F.relu(self.conv1_1(x))
h = F.relu(self.conv1_2(h))
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv2_1(h))
h = F.relu(self.conv2_2(h))
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv3_1(h))
h = F.relu(self.conv3_2(h))
h = F.relu(self.conv3_3(h))
f3_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv4_1(h))
h = F.relu(self.conv4_2(h))
h = F.relu(self.conv4_3(h))
f4_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv5_1(h))
h = F.relu(self.conv5_2(h))
h = F.relu(self.conv5_3(h))
f5_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.fc6(h))
h = F.relu(self.fc7(h))
ffc7 = h
h = F.relu(self.conv6_1(h))
h = F.relu(self.conv6_2(h))
f6_2 = h
h = F.relu(self.conv7_1(h))
h = F.relu(self.conv7_2(h))
f7_2 = h
f3_3 = self.conv3_3_norm(f3_3)
f4_3 = self.conv4_3_norm(f4_3)
f5_3 = self.conv5_3_norm(f5_3)
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
cls4 = self.fc7_mbox_conf(ffc7)
reg4 = self.fc7_mbox_loc(ffc7)
cls5 = self.conv6_2_mbox_conf(f6_2)
reg5 = self.conv6_2_mbox_loc(f6_2)
cls6 = self.conv7_2_mbox_conf(f7_2)
reg6 = self.conv7_2_mbox_loc(f7_2)
# max-out background label
chunk = torch.chunk(cls1, 4, 1)
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
cls1 = torch.cat([bmax, chunk[3]], dim=1)
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]

View file

@ -0,0 +1,59 @@
import os
import cv2
from torch.utils.model_zoo import load_url
from ..core import FaceDetector
from .net_s3fd import s3fd
from .bbox import *
from .detect import *
models_urls = {
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
}
class SFDDetector(FaceDetector):
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
super(SFDDetector, self).__init__(device, verbose)
# Initialise the face detector
if path_to_detector is None:
model_weights = load_url(models_urls['s3fd'])
else:
model_weights = torch.load(path_to_detector)
self.face_detector = s3fd()
self.face_detector.load_state_dict(model_weights)
self.face_detector.to(device)
self.face_detector.eval()
def detect_from_image(self, tensor_or_path):
image = self.tensor_or_path_to_ndarray(tensor_or_path)
bboxlist = detect(self.face_detector, image, device=self.device)
keep = nms(bboxlist, 0.3)
bboxlist = bboxlist[keep, :]
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
return bboxlist
def detect_from_batch(self, images):
bboxlists = batch_detect(self.face_detector, images, device=self.device)
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
return bboxlists
@property
def reference_scale(self):
return 195
@property
def reference_x_shift(self):
return 0
@property
def reference_y_shift(self):
return 0

261
face_detection/models.py Normal file
View file

@ -0,0 +1,261 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
stride=strd, padding=padding, bias=bias)
class ConvBlock(nn.Module):
def __init__(self, in_planes, out_planes):
super(ConvBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
if in_planes != out_planes:
self.downsample = nn.Sequential(
nn.BatchNorm2d(in_planes),
nn.ReLU(True),
nn.Conv2d(in_planes, out_planes,
kernel_size=1, stride=1, bias=False),
)
else:
self.downsample = None
def forward(self, x):
residual = x
out1 = self.bn1(x)
out1 = F.relu(out1, True)
out1 = self.conv1(out1)
out2 = self.bn2(out1)
out2 = F.relu(out2, True)
out2 = self.conv2(out2)
out3 = self.bn3(out2)
out3 = F.relu(out3, True)
out3 = self.conv3(out3)
out3 = torch.cat((out1, out2, out3), 1)
if self.downsample is not None:
residual = self.downsample(residual)
out3 += residual
return out3
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class HourGlass(nn.Module):
def __init__(self, num_modules, depth, num_features):
super(HourGlass, self).__init__()
self.num_modules = num_modules
self.depth = depth
self.features = num_features
self._generate_network(self.depth)
def _generate_network(self, level):
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
if level > 1:
self._generate_network(level - 1)
else:
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
def _forward(self, level, inp):
# Upper branch
up1 = inp
up1 = self._modules['b1_' + str(level)](up1)
# Lower branch
low1 = F.avg_pool2d(inp, 2, stride=2)
low1 = self._modules['b2_' + str(level)](low1)
if level > 1:
low2 = self._forward(level - 1, low1)
else:
low2 = low1
low2 = self._modules['b2_plus_' + str(level)](low2)
low3 = low2
low3 = self._modules['b3_' + str(level)](low3)
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
return up1 + up2
def forward(self, x):
return self._forward(self.depth, x)
class FAN(nn.Module):
def __init__(self, num_modules=1):
super(FAN, self).__init__()
self.num_modules = num_modules
# Base part
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.conv2 = ConvBlock(64, 128)
self.conv3 = ConvBlock(128, 128)
self.conv4 = ConvBlock(128, 256)
# Stacking part
for hg_module in range(self.num_modules):
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
self.add_module('conv_last' + str(hg_module),
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
self.add_module('l' + str(hg_module), nn.Conv2d(256,
68, kernel_size=1, stride=1, padding=0))
if hg_module < self.num_modules - 1:
self.add_module(
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
self.add_module('al' + str(hg_module), nn.Conv2d(68,
256, kernel_size=1, stride=1, padding=0))
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)), True)
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
x = self.conv3(x)
x = self.conv4(x)
previous = x
outputs = []
for i in range(self.num_modules):
hg = self._modules['m' + str(i)](previous)
ll = hg
ll = self._modules['top_m_' + str(i)](ll)
ll = F.relu(self._modules['bn_end' + str(i)]
(self._modules['conv_last' + str(i)](ll)), True)
# Predict heatmaps
tmp_out = self._modules['l' + str(i)](ll)
outputs.append(tmp_out)
if i < self.num_modules - 1:
ll = self._modules['bl' + str(i)](ll)
tmp_out_ = self._modules['al' + str(i)](tmp_out)
previous = previous + ll + tmp_out_
return outputs
class ResNetDepth(nn.Module):
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
self.inplanes = 64
super(ResNetDepth, self).__init__()
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x

313
face_detection/utils.py Normal file
View file

@ -0,0 +1,313 @@
from __future__ import print_function
import os
import sys
import time
import torch
import math
import numpy as np
import cv2
def _gaussian(
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
mean_vert=0.5):
# handle some defaults
if width is None:
width = size
if height is None:
height = size
if sigma_horz is None:
sigma_horz = sigma
if sigma_vert is None:
sigma_vert = sigma
center_x = mean_horz * width + 0.5
center_y = mean_vert * height + 0.5
gauss = np.empty((height, width), dtype=np.float32)
# generate kernel
for i in range(height):
for j in range(width):
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
if normalize:
gauss = gauss / np.sum(gauss)
return gauss
def draw_gaussian(image, point, sigma):
# Check if the gaussian is inside
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
return image
size = 6 * sigma + 1
g = _gaussian(size)
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
assert (g_x[0] > 0 and g_y[1] > 0)
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
image[image > 1] = 1
return image
def transform(point, center, scale, resolution, invert=False):
"""Generate and affine transformation matrix.
Given a set of points, a center, a scale and a targer resolution, the
function generates and affine transformation matrix. If invert is ``True``
it will produce the inverse transformation.
Arguments:
point {torch.tensor} -- the input 2D point
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
scale {float} -- the scale of the face/object
resolution {float} -- the output resolution
Keyword Arguments:
invert {bool} -- define wherever the function should produce the direct or the
inverse transformation matrix (default: {False})
"""
_pt = torch.ones(3)
_pt[0] = point[0]
_pt[1] = point[1]
h = 200.0 * scale
t = torch.eye(3)
t[0, 0] = resolution / h
t[1, 1] = resolution / h
t[0, 2] = resolution * (-center[0] / h + 0.5)
t[1, 2] = resolution * (-center[1] / h + 0.5)
if invert:
t = torch.inverse(t)
new_point = (torch.matmul(t, _pt))[0:2]
return new_point.int()
def crop(image, center, scale, resolution=256.0):
"""Center crops an image or set of heatmaps
Arguments:
image {numpy.array} -- an rgb image
center {numpy.array} -- the center of the object, usually the same as of the bounding box
scale {float} -- scale of the face
Keyword Arguments:
resolution {float} -- the size of the output cropped image (default: {256.0})
Returns:
[type] -- [description]
""" # Crop around the center point
""" Crops the image around the center. Input is expected to be an np.ndarray """
ul = transform([1, 1], center, scale, resolution, True)
br = transform([resolution, resolution], center, scale, resolution, True)
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
if image.ndim > 2:
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
image.shape[2]], dtype=np.int32)
newImg = np.zeros(newDim, dtype=np.uint8)
else:
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
newImg = np.zeros(newDim, dtype=np.uint8)
ht = image.shape[0]
wd = image.shape[1]
newX = np.array(
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
newY = np.array(
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
interpolation=cv2.INTER_LINEAR)
return newImg
def get_preds_fromhm(hm, center=None, scale=None):
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
and the scale is provided the function will return the points also in
the original coordinate frame.
Arguments:
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
Keyword Arguments:
center {torch.tensor} -- the center of the bounding box (default: {None})
scale {float} -- face scale (default: {None})
"""
max, idx = torch.max(
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
idx += 1
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
for i in range(preds.size(0)):
for j in range(preds.size(1)):
hm_ = hm[i, j, :]
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
diff = torch.FloatTensor(
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
preds[i, j].add_(diff.sign_().mul_(.25))
preds.add_(-.5)
preds_orig = torch.zeros(preds.size())
if center is not None and scale is not None:
for i in range(hm.size(0)):
for j in range(hm.size(1)):
preds_orig[i, j] = transform(
preds[i, j], center, scale, hm.size(2), True)
return preds, preds_orig
def get_preds_fromhm_batch(hm, centers=None, scales=None):
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
and the scales is provided the function will return the points also in
the original coordinate frame.
Arguments:
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
Keyword Arguments:
centers {torch.tensor} -- the centers of the bounding box (default: {None})
scales {float} -- face scales (default: {None})
"""
max, idx = torch.max(
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
idx += 1
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
for i in range(preds.size(0)):
for j in range(preds.size(1)):
hm_ = hm[i, j, :]
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
diff = torch.FloatTensor(
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
preds[i, j].add_(diff.sign_().mul_(.25))
preds.add_(-.5)
preds_orig = torch.zeros(preds.size())
if centers is not None and scales is not None:
for i in range(hm.size(0)):
for j in range(hm.size(1)):
preds_orig[i, j] = transform(
preds[i, j], centers[i], scales[i], hm.size(2), True)
return preds, preds_orig
def shuffle_lr(parts, pairs=None):
"""Shuffle the points left-right according to the axis of symmetry
of the object.
Arguments:
parts {torch.tensor} -- a 3D or 4D object containing the
heatmaps.
Keyword Arguments:
pairs {list of integers} -- [order of the flipped points] (default: {None})
"""
if pairs is None:
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
62, 61, 60, 67, 66, 65]
if parts.ndimension() == 3:
parts = parts[pairs, ...]
else:
parts = parts[:, pairs, ...]
return parts
def flip(tensor, is_label=False):
"""Flip an image or a set of heatmaps left-right
Arguments:
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
Keyword Arguments:
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
"""
if not torch.is_tensor(tensor):
tensor = torch.from_numpy(tensor)
if is_label:
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
else:
tensor = tensor.flip(tensor.ndimension() - 1)
return tensor
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
def appdata_dir(appname=None, roaming=False):
""" appdata_dir(appname=None, roaming=False)
Get the path to the application directory, where applications are allowed
to write user specific files (e.g. configurations). For non-user specific
data, consider using common_appdata_dir().
If appname is given, a subdir is appended (and created if necessary).
If roaming is True, will prefer a roaming directory (Windows Vista/7).
"""
# Define default user directory
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
if userDir is None:
userDir = os.path.expanduser('~')
if not os.path.isdir(userDir): # pragma: no cover
userDir = '/var/tmp' # issue #54
# Get system app data dir
path = None
if sys.platform.startswith('win'):
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
path = (path2 or path1) if roaming else (path1 or path2)
elif sys.platform.startswith('darwin'):
path = os.path.join(userDir, 'Library', 'Application Support')
# On Linux and as fallback
if not (path and os.path.isdir(path)):
path = userDir
# Maybe we should store things local to the executable (in case of a
# portable distro or a frozen application that wants to be portable)
prefix = sys.prefix
if getattr(sys, 'frozen', None):
prefix = os.path.abspath(os.path.dirname(sys.executable))
for reldir in ('settings', '../settings'):
localpath = os.path.abspath(os.path.join(prefix, reldir))
if os.path.isdir(localpath): # pragma: no cover
try:
open(os.path.join(localpath, 'test.write'), 'wb').close()
os.remove(os.path.join(localpath, 'test.write'))
except IOError:
pass # We cannot write in this directory
else:
path = localpath
break
# Get path specific for this app
if appname:
if path == userDir:
appname = '.' + appname.lstrip('.') # Make it a hidden directory
path = os.path.join(path, appname)
if not os.path.isdir(path): # pragma: no cover
os.mkdir(path)
# Done
return path

1
filelists/README.md Normal file
View file

@ -0,0 +1 @@
Place LRS2 (and any other) filelists here for training.

86
hparams.py Normal file
View file

@ -0,0 +1,86 @@
from tensorflow.contrib.training import HParams
from glob import glob
import os
def get_image_list(data_root, split):
filelist = []
with open('filelists/{}.txt'.format(split)) as f:
for line in f:
line = line.strip()
if ' ' in line: line = line.split()[0]
filelist.append(os.path.join(data_root, line))
return filelist
# Default hyperparameters
hparams = HParams(
num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
# network
rescale=True, # Whether to rescale audio prior to preprocessing
rescaling_max=0.9, # Rescaling value
# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
# Does not work if n_ffit is not multiple of hop_size!!
use_lws=False,
n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
# Mel and Linear spectrograms normalization/scaling and clipping
signal_normalization=True,
# Whether to normalize mel spectrograms to some predefined range (following below parameters)
allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
symmetric_mels=True,
# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
# faster and cleaner convergence)
max_abs_value=4.,
# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
# be too big to avoid gradient explosion,
# not too small for fast convergence)
# Contribution by @begeekmyfriend
# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
# levels. Also allows for better G&L phase reconstruction)
preemphasize=True, # whether to apply filter
preemphasis=0.97, # filter coefficient.
# Limits
min_level_db=-100,
ref_level_db=20,
fmin=55,
# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
fmax=7600, # To be increased/reduced depending on data.
###################### Our training parameters #################################
img_size=96,
fps=25,
batch_size=16,
initial_learning_rate=1e-4,
nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
num_workers=16,
checkpoint_interval=3000,
eval_interval=3000,
save_optimizer_state=True,
syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
syncnet_batch_size=64,
syncnet_lr=1e-4,
syncnet_eval_interval=10000,
syncnet_checkpoint_interval=10000,
disc_wt=0.07,
disc_initial_learning_rate=1e-4,
)
def hparams_debug_string():
values = hparams.values()
hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
return "Hyperparameters:\n" + "\n".join(hp)

443
hq_wav2lip_train.py Normal file
View file

@ -0,0 +1,443 @@
from os.path import dirname, join, basename, isfile
from tqdm import tqdm
from models import SyncNet_color as SyncNet
from models import Wav2Lip, Wav2Lip_disc_qual
import audio
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from glob import glob
import os, random, cv2, argparse
from hparams import hparams, get_image_list
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator')
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)
args = parser.parse_args()
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))
syncnet_T = 5
syncnet_mel_step_size = 16
class Dataset(object):
def __init__(self, split):
self.all_videos = get_image_list(args.data_root, split)
def get_frame_id(self, frame):
return int(basename(frame).split('.')[0])
def get_window(self, start_frame):
start_id = self.get_frame_id(start_frame)
vidname = dirname(start_frame)
window_fnames = []
for frame_id in range(start_id, start_id + syncnet_T):
frame = join(vidname, '{}.jpg'.format(frame_id))
if not isfile(frame):
return None
window_fnames.append(frame)
return window_fnames
def read_window(self, window_fnames):
if window_fnames is None: return None
window = []
for fname in window_fnames:
img = cv2.imread(fname)
if img is None:
return None
try:
img = cv2.resize(img, (hparams.img_size, hparams.img_size))
except Exception as e:
return None
window.append(img)
return window
def crop_audio_window(self, spec, start_frame):
if type(start_frame) == int:
start_frame_num = start_frame
else:
start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
end_idx = start_idx + syncnet_mel_step_size
return spec[start_idx : end_idx, :]
def get_segmented_mels(self, spec, start_frame):
mels = []
assert syncnet_T == 5
start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
if start_frame_num - 2 < 0: return None
for i in range(start_frame_num, start_frame_num + syncnet_T):
m = self.crop_audio_window(spec, i - 2)
if m.shape[0] != syncnet_mel_step_size:
return None
mels.append(m.T)
mels = np.asarray(mels)
return mels
def prepare_window(self, window):
# 3 x T x H x W
x = np.asarray(window) / 255.
x = np.transpose(x, (3, 0, 1, 2))
return x
def __len__(self):
return len(self.all_videos)
def __getitem__(self, idx):
while 1:
idx = random.randint(0, len(self.all_videos) - 1)
vidname = self.all_videos[idx]
img_names = list(glob(join(vidname, '*.jpg')))
if len(img_names) <= 3 * syncnet_T:
continue
img_name = random.choice(img_names)
wrong_img_name = random.choice(img_names)
while wrong_img_name == img_name:
wrong_img_name = random.choice(img_names)
window_fnames = self.get_window(img_name)
wrong_window_fnames = self.get_window(wrong_img_name)
if window_fnames is None or wrong_window_fnames is None:
continue
window = self.read_window(window_fnames)
if window is None:
continue
wrong_window = self.read_window(wrong_window_fnames)
if wrong_window is None:
continue
try:
wavpath = join(vidname, "audio.wav")
wav = audio.load_wav(wavpath, hparams.sample_rate)
orig_mel = audio.melspectrogram(wav).T
except Exception as e:
continue
mel = self.crop_audio_window(orig_mel.copy(), img_name)
if (mel.shape[0] != syncnet_mel_step_size):
continue
indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
if indiv_mels is None: continue
window = self.prepare_window(window)
y = window.copy()
window[:, :, window.shape[2]//2:] = 0.
wrong_window = self.prepare_window(wrong_window)
x = np.concatenate([window, wrong_window], axis=0)
x = torch.FloatTensor(x)
mel = torch.FloatTensor(mel.T).unsqueeze(0)
indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
y = torch.FloatTensor(y)
return x, indiv_mels, mel, y
def save_sample_images(x, g, gt, global_step, checkpoint_dir):
x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
refs, inps = x[..., 3:], x[..., :3]
folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
if not os.path.exists(folder): os.mkdir(folder)
collage = np.concatenate((refs, inps, g, gt), axis=-2)
for batch_idx, c in enumerate(collage):
for t in range(len(c)):
cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
logloss = nn.BCELoss()
def cosine_loss(a, v, y):
d = nn.functional.cosine_similarity(a, v)
loss = logloss(d.unsqueeze(1), y)
return loss
device = torch.device("cuda" if use_cuda else "cpu")
syncnet = SyncNet().to(device)
for p in syncnet.parameters():
p.requires_grad = False
recon_loss = nn.L1Loss()
def get_sync_loss(mel, g):
g = g[:, :, :, g.size(3)//2:]
g = torch.cat([g[:, :, i] for i in range(hparams.syncnet_T)], dim=1)
# B, 3 * T, H//2, W
a, v = syncnet(mel, g)
y = torch.ones(g.size(0), 1).float().to(device)
return cosine_loss(a, v, y)
def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
global global_step, global_epoch
resumed_step = global_step
while global_epoch < nepochs:
print('Starting Epoch: {}'.format(global_epoch))
running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
running_disc_real_loss, running_disc_fake_loss = 0., 0.
prog_bar = tqdm(enumerate(train_data_loader))
for step, (x, indiv_mels, mel, gt) in prog_bar:
disc.train()
model.train()
x = x.to(device)
mel = mel.to(device)
indiv_mels = indiv_mels.to(device)
gt = gt.to(device)
### Train generator now. Remove ALL grads.
optimizer.zero_grad()
disc_optimizer.zero_grad()
g = model(indiv_mels, x)
if hparams.syncnet_wt > 0.:
sync_loss = get_sync_loss(mel, g)
else:
sync_loss = 0.
if hparams.disc_wt > 0.:
perceptual_loss = disc.perceptual_forward(g)
else:
perceptual_loss = 0.
l1loss = recon_loss(g, gt)
loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
(1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
loss.backward()
optimizer.step()
### Remove all gradients before Training disc
disc_optimizer.zero_grad()
pred = disc(gt)
disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
disc_real_loss.backward()
pred = disc(g.detach())
disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
disc_fake_loss.backward()
disc_optimizer.step()
running_disc_real_loss += disc_real_loss.item()
running_disc_fake_loss += disc_fake_loss.item()
if global_step % checkpoint_interval == 0:
save_sample_images(x, g, gt, global_step, checkpoint_dir)
# Logs
global_step += 1
cur_session_steps = global_step - resumed_step
running_l1_loss += l1loss.item()
if hparams.syncnet_wt > 0.:
running_sync_loss += sync_loss.item()
else:
running_sync_loss += 0.
if hparams.disc_wt > 0.:
running_perceptual_loss += perceptual_loss.item()
else:
running_perceptual_loss += 0.
if global_step == 1 or global_step % checkpoint_interval == 0:
save_checkpoint(
model, optimizer, global_step, checkpoint_dir, global_epoch)
save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')
if global_step % hparams.eval_interval == 0:
with torch.no_grad():
average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc, checkpoint_dir)
if average_sync_loss < .75:
hparams.set_hparam('syncnet_wt', 0.03)
prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
running_sync_loss / (step + 1),
running_perceptual_loss / (step + 1),
running_disc_fake_loss / (step + 1),
running_disc_real_loss / (step + 1)))
global_epoch += 1
def eval_model(test_data_loader, global_step, writer, device, model, disc, checkpoint_dir):
eval_steps = 300
print('Evaluating for {} steps'.format(eval_steps))
running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []
while 1:
for step, (x, indiv_mels, mel, gt) in enumerate((test_data_loader)):
model.eval()
disc.eval()
x = x.to(device)
mel = mel.to(device)
indiv_mels = indiv_mels.to(device)
gt = gt.to(device)
pred = disc(gt)
disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
g = model(indiv_mels, x)
pred = disc(g)
disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
running_disc_real_loss.append(disc_real_loss.item())
running_disc_fake_loss.append(disc_fake_loss.item())
sync_loss = get_sync_loss(mel, g)
if hparams.disc_wt > 0.:
perceptual_loss = disc.perceptual_forward(g)
else:
perceptual_loss = 0.
l1loss = recon_loss(g, gt)
loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
(1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
running_l1_loss.append(l1loss.item())
running_sync_loss.append(sync_loss.item())
if hparams.disc_wt > 0.:
running_perceptual_loss.append(perceptual_loss.item())
else:
running_perceptual_loss.append(0.)
if step > eval_steps: break
print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss),
sum(running_sync_loss) / len(running_sync_loss),
sum(running_perceptual_loss) / len(running_perceptual_loss),
sum(running_disc_fake_loss) / len(running_disc_fake_loss),
sum(running_disc_real_loss) / len(running_disc_real_loss)))
return sum(running_sync_loss) / len(running_sync_loss)
def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
checkpoint_path = join(
checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, global_step))
optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
torch.save({
"state_dict": model.state_dict(),
"optimizer": optimizer_state,
"global_step": step,
"global_epoch": epoch,
}, checkpoint_path)
print("Saved checkpoint:", checkpoint_path)
def _load(checkpoint_path):
if use_cuda:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
global global_step
global global_epoch
print("Load checkpoint from: {}".format(path))
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
for k, v in s.items():
new_s[k.replace('module.', '')] = v
model.load_state_dict(new_s)
if not reset_optimizer:
optimizer_state = checkpoint["optimizer"]
if optimizer_state is not None:
print("Load optimizer state from {}".format(path))
optimizer.load_state_dict(checkpoint["optimizer"])
if overwrite_global_states:
global_step = checkpoint["global_step"]
global_epoch = checkpoint["global_epoch"]
return model
if __name__ == "__main__":
checkpoint_dir = args.checkpoint_dir
# Dataset and Dataloader setup
train_dataset = Dataset('train')
test_dataset = Dataset('val')
train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=hparams.batch_size, shuffle=True,
num_workers=hparams.num_workers)
test_data_loader = data_utils.DataLoader(
test_dataset, batch_size=hparams.batch_size,
num_workers=4)
device = torch.device("cuda" if use_cuda else "cpu")
# Model
model = Wav2Lip().to(device)
disc = Wav2Lip_disc_qual().to(device)
print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))
if args.checkpoint_path is not None:
load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
if args.disc_checkpoint_path is not None:
load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer,
reset_optimizer=False, overwrite_global_states=False)
load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True,
overwrite_global_states=False)
if not os.path.exists(checkpoint_dir):
os.mkdir(checkpoint_dir)
# Train!
train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
checkpoint_dir=checkpoint_dir,
checkpoint_interval=hparams.checkpoint_interval,
nepochs=hparams.nepochs)

249
inference.py Normal file
View file

@ -0,0 +1,249 @@
from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse, audio
import json, subprocess, random, string
from tqdm import tqdm
from glob import glob
import torch, face_detection
from models import Wav2Lip
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
parser.add_argument('--checkpoint_path', type=str,
help='Name of saved checkpoint to load weights from', required=True)
parser.add_argument('--face', type=str,
help='Filepath of video/image that contains faces to use', required=True)
parser.add_argument('--audio', type=str,
help='Filepath of video/audio file to use as raw audio source', required=True)
parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
default='results/result_voice.mp4')
parser.add_argument('--static', type=bool,
help='If True, then use only first video frame for inference', default=False)
parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
default=25., required=False)
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
help='Padding (top, bottom, left, right). Please adjust to include chin at least')
parser.add_argument('--face_det_batch_size', type=int,
help='Batch size for face detection', default=16)
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
parser.add_argument('--resize_factor', default=1, type=int,
help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
args = parser.parse_args()
args.img_size = 96
if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
args.static = True
def get_smoothened_boxes(boxes, T):
for i in range(len(boxes)):
if i + T > len(boxes):
window = boxes[len(boxes) - T:]
else:
window = boxes[i : i + T]
boxes[i] = np.mean(window, axis=0)
return boxes
def face_detect(images):
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
flip_input=False, device=device)
batch_size = args.face_det_batch_size
while 1:
predictions = []
try:
for i in tqdm(range(0, len(images), batch_size)):
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
except RuntimeError:
if batch_size == 1:
raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
batch_size //= 2
print('Recovering from OOM error; New batch size: {}'.format(batch_size))
continue
break
results = []
pady1, pady2, padx1, padx2 = args.pads
for rect, image in zip(predictions, images):
if rect is None:
raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
y1 = max(0, rect[1] - pady1)
y2 = min(image.shape[0], rect[3] + pady2)
x1 = max(0, rect[0] - padx1)
x2 = min(image.shape[1], rect[2] + padx2)
results.append([x1, y1, x2, y2])
boxes = get_smoothened_boxes(np.array(results), T=5)
results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
del detector
return results
def datagen(frames, mels):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
if not args.static:
face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
else:
face_det_results = face_detect([frames[0]])
for i, m in enumerate(mels):
idx = 0 if args.static else i%len(frames)
frame_to_save = frames[idx].copy()
face, coords = face_det_results[idx].copy()
face = cv2.resize(face, (args.img_size, args.img_size))
img_batch.append(face)
mel_batch.append(m)
frame_batch.append(frame_to_save)
coords_batch.append(coords)
if len(img_batch) >= args.wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, args.img_size//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, args.img_size//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
yield img_batch, mel_batch, frame_batch, coords_batch
mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
def _load(checkpoint_path):
if device == 'cuda':
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
def load_model(path):
model = Wav2Lip()
print("Load checkpoint from: {}".format(path))
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
for k, v in s.items():
new_s[k.replace('module.', '')] = v
model.load_state_dict(new_s)
model = model.to(device)
return model.eval()
def main():
if not os.path.isfile(args.face):
fnames = list(glob(os.path.join(args.face, '*.jpg')))
sorted_fnames = sorted(fnames, key=lambda f: int(os.path.basename(f).split('.')[0]))
full_frames = [cv2.imread(f) for f in sorted_fnames]
elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
full_frames = [cv2.imread(args.face)]
fps = args.fps
else:
video_stream = cv2.VideoCapture(args.face)
fps = video_stream.get(cv2.CAP_PROP_FPS)
print('Reading video frames...')
full_frames = []
while 1:
still_reading, frame = video_stream.read()
if not still_reading:
video_stream.release()
break
if args.resize_factor > 1:
frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
full_frames.append(frame)
print ("Number of frames available for inference: "+str(len(full_frames)))
if not args.audio.endswith('.wav'):
print('Extracting raw audio...')
command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
subprocess.call(command, shell=True)
args.audio = 'temp/temp.wav'
wav = audio.load_wav(args.audio, 16000)
mel = audio.melspectrogram(wav)
print(mel.shape)
if np.isnan(mel.reshape(-1)).sum() > 0:
raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
mel_chunks = []
mel_idx_multiplier = 80./fps
i = 0
while 1:
start_idx = int(i * mel_idx_multiplier)
if start_idx + mel_step_size > len(mel[0]):
break
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
i += 1
print("Length of mel chunks: {}".format(len(mel_chunks)))
full_frames = full_frames[:len(mel_chunks)]
batch_size = args.wav2lip_batch_size
gen = datagen(full_frames.copy(), mel_chunks)
for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
if i == 0:
model = load_model(args.checkpoint_path)
print ("Model loaded")
frame_h, frame_w = full_frames[0].shape[:-1]
out = cv2.VideoWriter('temp/result.avi',
cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
for p, f, c in zip(pred, frames, coords):
y1, y2, x1, x2 = c
p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
f[y1:y2, x1:x2] = p
out.write(f)
out.release()
command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
subprocess.call(command, shell=True)
if __name__ == '__main__':
main()

2
models/__init__.py Normal file
View file

@ -0,0 +1,2 @@
from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
from .syncnet import SyncNet_color

44
models/conv.py Normal file
View file

@ -0,0 +1,44 @@
import torch
from torch import nn
from torch.nn import functional as F
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
self.residual = residual
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
return self.act(out)
class nonorm_Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
)
self.act = nn.LeakyReLU(0.01, inplace=True)
def forward(self, x):
out = self.conv_block(x)
return self.act(out)
class Conv2dTranspose(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
def forward(self, x):
out = self.conv_block(x)
return self.act(out)

66
models/syncnet.py Normal file
View file

@ -0,0 +1,66 @@
import torch
from torch import nn
from torch.nn import functional as F
from .conv import Conv2d
class SyncNet_color(nn.Module):
def __init__(self):
super(SyncNet_color, self).__init__()
self.face_encoder = nn.Sequential(
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
face_embedding = self.face_encoder(face_sequences)
audio_embedding = self.audio_encoder(audio_sequences)
audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
face_embedding = face_embedding.view(face_embedding.size(0), -1)
audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
face_embedding = F.normalize(face_embedding, p=2, dim=1)
return audio_embedding, face_embedding

184
models/wav2lip.py Normal file
View file

@ -0,0 +1,184 @@
import torch
from torch import nn
from torch.nn import functional as F
import math
from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
class Wav2Lip(nn.Module):
def __init__(self):
super(Wav2Lip, self).__init__()
self.face_encoder_blocks = nn.ModuleList([
nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
self.face_decoder_blocks = nn.ModuleList([
nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
nn.Sigmoid())
def forward(self, audio_sequences, face_sequences):
# audio_sequences = (B, T, 1, 80, 16)
B = audio_sequences.size(0)
input_dim_size = len(face_sequences.size())
if input_dim_size > 4:
audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
feats = []
x = face_sequences
for f in self.face_encoder_blocks:
x = f(x)
feats.append(x)
x = audio_embedding
for f in self.face_decoder_blocks:
x = f(x)
try:
x = torch.cat((x, feats[-1]), dim=1)
except Exception as e:
print(x.size())
print(feats[-1].size())
raise e
feats.pop()
x = self.output_block(x)
if input_dim_size > 4:
x = torch.split(x, B, dim=0) # [(B, C, H, W)]
outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
else:
outputs = x
return outputs
class Wav2Lip_disc_qual(nn.Module):
def __init__(self):
super(Wav2Lip_disc_qual, self).__init__()
self.face_encoder_blocks = nn.ModuleList([
nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
self.label_noise = .0
def get_lower_half(self, face_sequences):
return face_sequences[:, :, face_sequences.size(2)//2:]
def to_2d(self, face_sequences):
B = face_sequences.size(0)
face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
return face_sequences
def perceptual_forward(self, false_face_sequences):
false_face_sequences = self.to_2d(false_face_sequences)
false_face_sequences = self.get_lower_half(false_face_sequences)
false_feats = false_face_sequences
for f in self.face_encoder_blocks:
false_feats = f(false_feats)
false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
torch.ones((len(false_feats), 1)).cuda())
return false_pred_loss
def forward(self, face_sequences):
face_sequences = self.to_2d(face_sequences)
face_sequences = self.get_lower_half(face_sequences)
x = face_sequences
for f in self.face_encoder_blocks:
x = f(x)
return self.binary_pred(x).view(len(x), -1)

112
preprocess.py Normal file
View file

@ -0,0 +1,112 @@
import sys
if sys.version_info[0] < 3 and sys.version_info[1] < 2:
raise Exception("Must be using >= Python 3.2")
from os import listdir, path
if not path.isfile('face_detection/detection/sfd/s3fd.pth'):
raise FileNotFoundError('Save the s3fd model to face_detection/sfd/s3fd.pth \
before running this script!')
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import argparse, os, cv2, traceback, subprocess
from tqdm import tqdm
from glob import glob
import audio
from hparams import hparams as hp
import face_detection
parser = argparse.ArgumentParser()
parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int)
parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int)
parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True)
parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True)
args = parser.parse_args()
fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
device='cuda:{}'.format(id)) for id in range(args.ngpu)]
template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'
def process_video_file(vfile, args, gpu_id):
video_stream = cv2.VideoCapture(vfile)
frames = []
while 1:
still_reading, frame = video_stream.read()
if not still_reading:
video_stream.release()
break
frames.append(frame)
vidname = os.path.basename(vfile).split('.')[0]
dirname = vfile.split('/')[-2]
fulldir = path.join(args.preprocessed_root, dirname, vidname)
os.makedirs(fulldir, exist_ok=True)
batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)]
i = -1
for fb in batches:
preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))
for j, f in enumerate(preds):
i += 1
if f is None:
continue
cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), f[0])
def process_audio_file(vfile, args):
vidname = os.path.basename(vfile).split('.')[0]
dirname = vfile.split('/')[-2]
fulldir = path.join(args.preprocessed_root, dirname, vidname)
os.makedirs(fulldir, exist_ok=True)
wavpath = path.join(fulldir, 'audio.wav')
command = template.format(vfile, wavpath)
subprocess.call(command, shell=True)
def mp_handler(job):
vfile, args, gpu_id = job
try:
process_video_file(vfile, args, gpu_id)
except KeyboardInterrupt:
exit(0)
except:
traceback.print_exc()
def main(args):
print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu))
filelist = glob(path.join(args.data_root, '*/*.mp4'))
jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)]
p = ThreadPoolExecutor(args.ngpu)
futures = [p.submit(mp_handler, j) for j in jobs]
_ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]
print('Dumping audios...')
for vfile in tqdm(filelist):
try:
process_audio_file(vfile, args)
except KeyboardInterrupt:
exit(0)
except:
traceback.print_exc()
continue
if __name__ == '__main__':
main(args)

8
requirements.txt Normal file
View file

@ -0,0 +1,8 @@
librosa==0.7.0
numpy==1.17.1
opencv-contrib-python==4.2.0.34
opencv-python==4.1.0.25
tensorflow-gpu==1.10.0
torch==1.1.0
torchvision==0.3.0
tqdm==4.45.0

1
results/README.md Normal file
View file

@ -0,0 +1 @@
Generated results will be placed in this folder by default.

1
temp/README.md Normal file
View file

@ -0,0 +1 @@
Temporary files at the time of inference/testing will be saved here. You can ignore them.

374
wav2lip_train.py Normal file
View file

@ -0,0 +1,374 @@
from os.path import dirname, join, basename, isfile
from tqdm import tqdm
from models import SyncNet_color as SyncNet
from models import Wav2Lip as Wav2Lip
import audio
import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from glob import glob
import os, random, cv2, argparse
from hparams import hparams, get_image_list
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model without the visual quality discriminator')
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None, type=str)
args = parser.parse_args()
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))
syncnet_T = 5
syncnet_mel_step_size = 16
class Dataset(object):
def __init__(self, split):
self.all_videos = get_image_list(args.data_root, split)
def get_frame_id(self, frame):
return int(basename(frame).split('.')[0])
def get_window(self, start_frame):
start_id = self.get_frame_id(start_frame)
vidname = dirname(start_frame)
window_fnames = []
for frame_id in range(start_id, start_id + syncnet_T):
frame = join(vidname, '{}.jpg'.format(frame_id))
if not isfile(frame):
return None
window_fnames.append(frame)
return window_fnames
def read_window(self, window_fnames):
if window_fnames is None: return None
window = []
for fname in window_fnames:
img = cv2.imread(fname)
if img is None:
return None
try:
img = cv2.resize(img, (hparams.img_size, hparams.img_size))
except Exception as e:
return None
window.append(img)
return window
def crop_audio_window(self, spec, start_frame):
if type(start_frame) == int:
start_frame_num = start_frame
else:
start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
end_idx = start_idx + syncnet_mel_step_size
return spec[start_idx : end_idx, :]
def get_segmented_mels(self, spec, start_frame):
mels = []
assert syncnet_T == 5
start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
if start_frame_num - 2 < 0: return None
for i in range(start_frame_num, start_frame_num + syncnet_T):
m = self.crop_audio_window(spec, i - 2)
if m.shape[0] != syncnet_mel_step_size:
return None
mels.append(m.T)
mels = np.asarray(mels)
return mels
def prepare_window(self, window):
# 3 x T x H x W
x = np.asarray(window) / 255.
x = np.transpose(x, (3, 0, 1, 2))
return x
def __len__(self):
return len(self.all_videos)
def __getitem__(self, idx):
while 1:
idx = random.randint(0, len(self.all_videos) - 1)
vidname = self.all_videos[idx]
img_names = list(glob(join(vidname, '*.jpg')))
if len(img_names) <= 3 * syncnet_T:
continue
img_name = random.choice(img_names)
wrong_img_name = random.choice(img_names)
while wrong_img_name == img_name:
wrong_img_name = random.choice(img_names)
window_fnames = self.get_window(img_name)
wrong_window_fnames = self.get_window(wrong_img_name)
if window_fnames is None or wrong_window_fnames is None:
continue
window = self.read_window(window_fnames)
if window is None:
continue
wrong_window = self.read_window(wrong_window_fnames)
if wrong_window is None:
continue
try:
wavpath = join(vidname, "audio.wav")
wav = audio.load_wav(wavpath, hparams.sample_rate)
orig_mel = audio.melspectrogram(wav).T
except Exception as e:
continue
mel = self.crop_audio_window(orig_mel.copy(), img_name)
if (mel.shape[0] != syncnet_mel_step_size):
continue
indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
if indiv_mels is None: continue
window = self.prepare_window(window)
y = window.copy()
window[:, :, window.shape[2]//2:] = 0.
wrong_window = self.prepare_window(wrong_window)
x = np.concatenate([window, wrong_window], axis=0)
x = torch.FloatTensor(x)
mel = torch.FloatTensor(mel.T).unsqueeze(0)
indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
y = torch.FloatTensor(y)
return x, indiv_mels, mel, y
def save_sample_images(x, g, gt, global_step, checkpoint_dir):
x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
refs, inps = x[..., 3:], x[..., :3]
folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
if not os.path.exists(folder): os.mkdir(folder)
collage = np.concatenate((refs, inps, g, gt), axis=-2)
for batch_idx, c in enumerate(collage):
for t in range(len(c)):
cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
logloss = nn.BCELoss()
def cosine_loss(a, v, y):
d = nn.functional.cosine_similarity(a, v)
loss = logloss(d.unsqueeze(1), y)
return loss
device = torch.device("cuda" if use_cuda else "cpu")
syncnet = SyncNet().to(device)
for p in syncnet.parameters():
p.requires_grad = False
recon_loss = nn.L1Loss()
def get_sync_loss(mel, g):
g = g[:, :, :, g.size(3)//2:]
g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
# B, 3 * T, H//2, W
a, v = syncnet(mel, g)
y = torch.ones(g.size(0), 1).float().to(device)
return cosine_loss(a, v, y)
def train(device, model, train_data_loader, test_data_loader, optimizer,
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
global global_step, global_epoch
resumed_step = global_step
while global_epoch < nepochs:
print('Starting Epoch: {}'.format(global_epoch))
running_sync_loss, running_l1_loss = 0., 0.
prog_bar = tqdm(enumerate(train_data_loader))
for step, (x, indiv_mels, mel, gt) in prog_bar:
model.train()
optimizer.zero_grad()
# Move data to CUDA device
x = x.to(device)
mel = mel.to(device)
indiv_mels = indiv_mels.to(device)
gt = gt.to(device)
g = model(indiv_mels, x)
if hparams.syncnet_wt > 0.:
sync_loss = get_sync_loss(mel, g)
else:
sync_loss = 0.
l1loss = recon_loss(g, gt)
loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss
loss.backward()
optimizer.step()
if global_step % checkpoint_interval == 0:
save_sample_images(x, g, gt, global_step, checkpoint_dir)
global_step += 1
cur_session_steps = global_step - resumed_step
running_l1_loss += l1loss.item()
if hparams.syncnet_wt > 0.:
running_sync_loss += sync_loss.item()
else:
running_sync_loss += 0.
if global_step == 1 or global_step % checkpoint_interval == 0:
save_checkpoint(
model, optimizer, global_step, checkpoint_dir, global_epoch)
if global_step == 1 or global_step % hparams.eval_interval == 0:
with torch.no_grad():
average_sync_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
if average_sync_loss < .75:
hparams.set_hparam('syncnet_wt', 0.03)
prog_bar.set_description('L1: {}, Sync Loss: {}'.format(running_l1_loss / (step + 1),
running_sync_loss / (step + 1)))
global_epoch += 1
def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
eval_steps = 700
print('Evaluating for {} steps'.format(eval_steps))
sync_losses, recon_losses = [], []
step = 0
while 1:
for x, indiv_mels, mel, gt in test_data_loader:
step += 1
model.eval()
# Move data to CUDA device
x = x.to(device)
gt = gt.to(device)
indiv_mels = indiv_mels.to(device)
mel = mel.to(device)
g = model(indiv_mels, x)
sync_loss = get_sync_loss(mel, g)
l1loss = recon_loss(g, gt)
sync_losses.append(sync_loss.item())
recon_losses.append(l1loss.item())
if step > eval_steps:
averaged_sync_loss = sum(sync_losses) / len(sync_losses)
averaged_recon_loss = sum(recon_losses) / len(recon_losses)
print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, averaged_sync_loss))
return averaged_sync_loss
def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
checkpoint_path = join(
checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
torch.save({
"state_dict": model.state_dict(),
"optimizer": optimizer_state,
"global_step": step,
"global_epoch": epoch,
}, checkpoint_path)
print("Saved checkpoint:", checkpoint_path)
def _load(checkpoint_path):
if use_cuda:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
global global_step
global global_epoch
print("Load checkpoint from: {}".format(path))
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
for k, v in s.items():
new_s[k.replace('module.', '')] = v
model.load_state_dict(new_s)
if not reset_optimizer:
optimizer_state = checkpoint["optimizer"]
if optimizer_state is not None:
print("Load optimizer state from {}".format(path))
optimizer.load_state_dict(checkpoint["optimizer"])
if overwrite_global_states:
global_step = checkpoint["global_step"]
global_epoch = checkpoint["global_epoch"]
return model
if __name__ == "__main__":
checkpoint_dir = args.checkpoint_dir
# Dataset and Dataloader setup
train_dataset = Dataset('train')
test_dataset = Dataset('val')
train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=hparams.batch_size, shuffle=True,
num_workers=hparams.num_workers)
test_data_loader = data_utils.DataLoader(
test_dataset, batch_size=hparams.batch_size,
num_workers=4)
device = torch.device("cuda" if use_cuda else "cpu")
# Model
model = Wav2Lip().to(device)
print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
lr=hparams.initial_learning_rate)
if args.checkpoint_path is not None:
load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)
if not os.path.exists(checkpoint_dir):
os.mkdir(checkpoint_dir)
# Train!
train(device, model, train_data_loader, test_data_loader, optimizer,
checkpoint_dir=checkpoint_dir,
checkpoint_interval=hparams.checkpoint_interval,
nepochs=hparams.nepochs)