mirror of
https://github.com/NVIDIA/vid2vid.git
synced 2026-02-01 17:26:51 +00:00
add face and pose code
This commit is contained in:
parent
944d67d706
commit
e2d623dfa4
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,5 +1,6 @@
|
||||
debug*
|
||||
checkpoints/
|
||||
datasets/
|
||||
models/flownet2*
|
||||
results/
|
||||
build/
|
||||
|
||||
160
README.md
160
README.md
@ -42,78 +42,123 @@ Pytorch implementation for high-resolution (e.g., 2048x1024) photorealistic vide
|
||||
```bash
|
||||
pip install dominate requests
|
||||
```
|
||||
- If you plan to train with face datasets, please install dlib.
|
||||
```bash
|
||||
pip install dlib
|
||||
```
|
||||
- If you plan to train with pose datasets, please install [DensePose](https://github.com/facebookresearch/DensePose) and/or [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose).
|
||||
- Clone this repo:
|
||||
```bash
|
||||
git clone https://github.com/NVIDIA/vid2vid
|
||||
cd vid2vid
|
||||
```
|
||||
|
||||
### Testing
|
||||
- We include an example Cityscapes video in the `datasets` folder.
|
||||
- First, download and compile a snapshot of [FlowNet2](https://github.com/NVIDIA/flownet2-pytorch) by running `python scripts/download_flownet2.py`.
|
||||
- Please download the pre-trained Cityscapes model by:
|
||||
```bash
|
||||
python scripts/download_models.py
|
||||
```
|
||||
- To test the model (`bash ./scripts/test_2048.sh`):
|
||||
```bash
|
||||
#!./scripts/test_2048.sh
|
||||
python test.py --name label2city_2048 --dataroot datasets/Cityscapes/test_A --loadSize 2048 --n_scales_spatial 3 --use_instance --fg --use_single_G
|
||||
```
|
||||
The test results will be saved to a HTML file here: `./results/label2city_2048/test_latest/index.html`.
|
||||
### Testing
|
||||
- Please first download example dataset by running `python scripts/download_datasets.py`.
|
||||
- Next, download and compile a snapshot of [FlowNet2](https://github.com/NVIDIA/flownet2-pytorch) by running `python scripts/download_flownet2.py`.
|
||||
- Cityscapes
|
||||
- Please download the pre-trained Cityscapes model by:
|
||||
```bash
|
||||
python scripts/street/download_models.py
|
||||
```
|
||||
- To test the model (`bash ./scripts/street/test_2048.sh`):
|
||||
```bash
|
||||
#!./scripts/street/test_2048.sh
|
||||
python test.py --name label2city_2048 --label_nc 35 --loadSize 2048 --n_scales_spatial 3 --use_instance --fg --use_single_G
|
||||
```
|
||||
The test results will be saved in: `./results/label2city_2048/test_latest/`.
|
||||
|
||||
- We also provide a smaller model trained with single GPU, which produces slightly worse performance at 1024 x 512 resolution.
|
||||
- Please download the model by
|
||||
```bash
|
||||
python scripts/download_models_g1.py
|
||||
```
|
||||
- To test the model (`bash ./scripts/test_1024_g1.sh`):
|
||||
```bash
|
||||
#!./scripts/test_1024_g1.sh
|
||||
python test.py --name label2city_1024_g1 --dataroot datasets/Cityscapes/test_A --loadSize 1024 --n_scales_spatial 3 --use_instance --fg --n_downsample_G 2 --use_single_G
|
||||
```
|
||||
- We also provide a smaller model trained with single GPU, which produces slightly worse performance at 1024 x 512 resolution.
|
||||
- Please download the model by
|
||||
```bash
|
||||
python scripts/street/download_models_g1.py
|
||||
```
|
||||
- To test the model (`bash ./scripts/street/test_g1_1024.sh`):
|
||||
```bash
|
||||
#!./scripts/street/test_g1_1024.sh
|
||||
python test.py --name label2city_1024_g1 --label_nc 35 --loadSize 1024 --n_scales_spatial 3 --use_instance --fg --n_downsample_G 2 --use_single_G
|
||||
```
|
||||
- You can find more example scripts in the `scripts/street/` directory.
|
||||
|
||||
- You can find more example scripts in the `scripts` directory.
|
||||
- Faces
|
||||
- Please download the pre-trained model by:
|
||||
```bash
|
||||
python scripts/face/download_models.py
|
||||
```
|
||||
- To test the model (`bash ./scripts/face/test_512.sh`):
|
||||
```bash
|
||||
#!./scripts/face/test_512.sh
|
||||
python test.py --name edge2face_512 --dataroot datasets/face/ --dataset_mode face --input_nc 15 --loadSize 512 --use_single_G
|
||||
```
|
||||
The test results will be saved in: `./results/edge2face_512/test_latest/`.
|
||||
|
||||
### Dataset
|
||||
- We use the Cityscapes dataset as an example. To train a model on the full dataset, please download it from the [official website](https://www.cityscapes-dataset.com/) (registration required).
|
||||
- We apply a pre-trained segmentation algorithm to get the corresponding semantic maps (train_A) and instance maps (train_inst).
|
||||
- Please add the obtained images to the `datasets` folder in the same way the example images are provided.
|
||||
- Cityscapes
|
||||
- We use the Cityscapes dataset as an example. To train a model on the full dataset, please download it from the [official website](https://www.cityscapes-dataset.com/) (registration required).
|
||||
- We apply a pre-trained segmentation algorithm to get the corresponding semantic maps (train_A) and instance maps (train_inst).
|
||||
- Please add the obtained images to the `datasets` folder in the same way the example images are provided.
|
||||
- Face
|
||||
- We use the [FaceForensics](http://niessnerlab.org/projects/roessler2018faceforensics.html) dataset. We then use landmark detection to estimate the face keypoints, and interpolate them to get face edges.
|
||||
- Pose
|
||||
- We use random dancing videos found on YouTube. We then apply DensePose / OpenPose to estimate the poses for each frame.
|
||||
|
||||
|
||||
### Training
|
||||
### Training with Cityscapes dataset
|
||||
- First, download the FlowNet2 checkpoint file by running `python scripts/download_models_flownet2.py`.
|
||||
- Training with 8 GPUs:
|
||||
- We adopt a coarse-to-fine approach, sequentially increasing the resolution from 512 x 256, 1024 x 512, to 2048 x 1024.
|
||||
- Train a model at 512 x 256 resolution (`bash ./scripts/train_512.sh`)
|
||||
- Train a model at 512 x 256 resolution (`bash ./scripts/street/train_512.sh`)
|
||||
```bash
|
||||
#!./scripts/train_512.sh
|
||||
python train.py --name label2city_512 --gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 6 --n_frames_total 6 --use_instance --fg
|
||||
#!./scripts/street/train_512.sh
|
||||
python train.py --name label2city_512 --label_nc 35 --gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 6 --n_frames_total 6 --use_instance --fg
|
||||
```
|
||||
- Train a model at 1024 x 512 resolution (must train 512 x 256 first) (`bash ./scripts/train_1024.sh`):
|
||||
- Train a model at 1024 x 512 resolution (must train 512 x 256 first) (`bash ./scripts/street/train_1024.sh`):
|
||||
```bash
|
||||
#!./scripts/train_1024.sh
|
||||
python train.py --name label2city_1024 --loadSize 1024 --n_scales_spatial 2 --num_D 3 --gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 4 --use_instance --fg --niter_step 2 --niter_fix_global 10 --load_pretrain checkpoints/label2city_512
|
||||
#!./scripts/street/train_1024.sh
|
||||
python train.py --name label2city_1024 --label_nc 35 --loadSize 1024 --n_scales_spatial 2 --num_D 3 --gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 4 --use_instance --fg --niter_step 2 --niter_fix_global 10 --load_pretrain checkpoints/label2city_512
|
||||
```
|
||||
- To view training results, please checkout intermediate results in `./checkpoints/label2city_1024/web/index.html`.
|
||||
If you have TensorFlow installed, you can see TensorBoard logs in `./checkpoints/label2city_1024/logs` by adding `--tf_log` to the training scripts.
|
||||
|
||||
- Training with a single GPU:
|
||||
- We trained our models using multiple GPUs. For convenience, we provide some sample training scripts (XXX_g1.sh) for single GPU users, up to 1024 x 512 resolution. Again a coarse-to-fine approach is adopted (256 x 128, 512 x 256, 1024 x 512). Performance is not guaranteed using these scripts.
|
||||
- For example, to train a 256 x 128 video with a single GPU (`bash ./scripts/train_256_g1.sh`)
|
||||
- We trained our models using multiple GPUs. For convenience, we provide some sample training scripts (train_g1_XXX.sh) for single GPU users, up to 1024 x 512 resolution. Again a coarse-to-fine approach is adopted (256 x 128, 512 x 256, 1024 x 512). Performance is not guaranteed using these scripts.
|
||||
- For example, to train a 256 x 128 video with a single GPU (`bash ./scripts/street/train_g1_256.sh`)
|
||||
```bash
|
||||
#!./scripts/train_256_g1.sh
|
||||
python train.py --name label2city_256_g1 --loadSize 256 --use_instance --fg --n_downsample_G 2 --num_D 1 --max_frames_per_gpu 6 --n_frames_total 6
|
||||
#!./scripts/street/train_g1_256.sh
|
||||
python train.py --name label2city_256_g1 --label_nc 35 --loadSize 256 --use_instance --fg --n_downsample_G 2 --num_D 1 --max_frames_per_gpu 6 --n_frames_total 6
|
||||
```
|
||||
|
||||
### Training at full (2k x 1k) resolution
|
||||
- To train the images at full resolution (2048 x 1024) requires 8 GPUs with at least 24G memory (`bash ./scripts/train_2048.sh`).
|
||||
If only GPUs with 12G/16G memory are available, please use the script `./scripts/train_2048_crop.sh`, which will crop the images during training. Performance is not guaranteed with this script.
|
||||
- Training at full (2k x 1k) resolution
|
||||
- To train the images at full resolution (2048 x 1024) requires 8 GPUs with at least 24G memory (`bash ./scripts/street/train_2048.sh`). If only GPUs with 12G/16G memory are available, please use the script `./scripts/street/train_2048_crop.sh`, which will crop the images during training. Performance is not guaranteed with this script.
|
||||
|
||||
### Training with face datasets
|
||||
- If you haven't, please first download example dataset by running `python scripts/download_datasets.py`.
|
||||
- Run the following command to compute face landmarks for training dataset:
|
||||
```bash
|
||||
python data/face_landmark_detection.py train
|
||||
```
|
||||
- Run the example script (`bash ./scripts/face/train_512.sh`)
|
||||
```bash
|
||||
python train.py --name edge2face_512 --dataroot datasets/face/ --dataset_mode face --input_nc 15 --loadSize 512 --num_D 3 --gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 6 --n_frames_total 12
|
||||
```
|
||||
- For single GPU users, example scripts are in train_g1_XXX.sh. These scripts are not fully tested and please use at your own discretion. If you still hit out of memory errors, try reducing `max_frames_per_gpu`.
|
||||
- More examples scripts can be found in `scripts/face/`.
|
||||
- Please refer to [More Training/Test Details](https://github.com/NVIDIA/vid2vid#more-trainingtest-details) for more explanations about training flags.
|
||||
|
||||
|
||||
### Training with pose datasets
|
||||
- If you haven't, please first download example dataset by running `python scripts/download_datasets.py`.
|
||||
- Example DensePose and OpenPose results are included. If you plan to use your own dataset, please generate these results and put them in the same way the example dataset is provided.
|
||||
- Run the example script (`bash ./scripts/pose/train_256p.sh`)
|
||||
```bash
|
||||
python train.py --name pose2body_256p --dataroot datasets/pose --dataset_mode pose --input_nc 6 --num_D 2 --resize_or_crop ScaleHeight_and_scaledCrop --loadSize 384 --fineSize 256 --gpu_ids 0,1,2,3,4,5,6,7 --batchSize 8 --max_frames_per_gpu 3 --no_first_img --n_frames_total 12 --max_t_step 4
|
||||
```
|
||||
- Again, for single GPU users, example scripts are in train_g1_XXX.sh. These scripts are not fully tested and please use at your own discretion. If you still hit out of memory errors, try reducing `max_frames_per_gpu`.
|
||||
- More examples scripts can be found in `scripts/pose/`.
|
||||
- Please refer to [More Training/Test Details](https://github.com/NVIDIA/vid2vid#more-trainingtest-details) for more explanations about training flags.
|
||||
|
||||
### Training with your own dataset
|
||||
- If your input is a label map, please generate label maps which are one-channel whose pixel values correspond to the object labels (i.e. 0,1,...,N-1, where N is the number of labels). This is because we need to generate one-hot vectors from the label maps. Please use `--label_nc N` during both training and testing.
|
||||
- If your input is not a label map, please specify `--label_nc 0` and `--input_nc N` where N is the number of input channels (The default is 3 for RGB images).
|
||||
- The default setting for preprocessing is `scaleWidth`, which will scale the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio. If you want a different setting, please change it by using the `--resize_or_crop` option. For example, `scaleWidth_and_crop` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. `scaledCrop` crops the image while retraining the original aspect ratio. If you don't want any preprocessing, please specify `none`, which will do nothing other than making sure the image is divisible by 32.
|
||||
- If your input is not a label map, please specify `--input_nc N` where N is the number of input channels (The default is 3 for RGB images).
|
||||
- The default setting for preprocessing is `scaleWidth`, which will scale the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio. If you want a different setting, please change it by using the `--resize_or_crop` option. For example, `scaleWidth_and_crop` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. `scaledCrop` crops the image while retraining the original aspect ratio. `randomScaleHeight` will randomly scale the image height to be between `opt.loadSize` and `opt.fineSize`. If you don't want any preprocessing, please specify `none`, which will do nothing other than making sure the image is divisible by 32.
|
||||
|
||||
## More Training/Test Details
|
||||
- We generate frames in the video sequentially, where the generation of the current frame depends on previous frames. To generate the first frame for the model, there are 3 different ways:
|
||||
@ -127,7 +172,7 @@ If only GPUs with 12G/16G memory are available, please use the script `./scripts
|
||||
- `n_frames_D`: the number of frames to feed into the temporal discriminator. The default is 3.
|
||||
- `n_scales_spatial`: the number of scales in the spatial domain. We train from the coarsest scale and all the way to the finest scale. The default is 3.
|
||||
- `n_scales_temporal`: the number of scales for the temporal discriminator. The finest scale takes in the sequence in the original frame rate. The coarser scales subsample the frames by a factor of `n_frames_D` before feeding the frames into the discriminator. For example, if `n_frames_D = 3` and `n_scales_temporal = 3`, the discriminator effectively sees 27 frames. The default is 3.
|
||||
- `max_frames_per_gpu`: the number of frames in one GPU during training. If your GPU memory can fit more frames, try to make this number bigger. The default is 1.
|
||||
- `max_frames_per_gpu`: the number of frames in one GPU during training. If you run into out of memory error, please first try to reduce this number. If your GPU memory can fit more frames, try to make this number bigger to make training faster. The default is 1.
|
||||
- `max_frames_backpropagate`: the number of frames that loss backpropagates to previous frames. For example, if this number is 4, the loss on frame n will backpropagate to frame n-3. Increasing this number will slightly improve the performance, but also cause training to be less stable. The default is 1.
|
||||
- `n_frames_total`: the total number of frames in a sequence we want to train with. We gradually increase this number during training.
|
||||
- `niter_step`: for how many epochs do we double `n_frames_total`. The default is 5.
|
||||
@ -135,18 +180,31 @@ If only GPUs with 12G/16G memory are available, please use the script `./scripts
|
||||
- `batchSize`: the number of sequences to train at a time. We normally set batchSize to 1 since often, one sequence is enough to occupy all GPUs. If you want to do batchSize > 1, currently only `batchSize == n_gpus_gen` is supported.
|
||||
- `no_first_img`: if not specified, the model will assume the first frame is given and synthesize the successive frames. If specified, the model will also try to synthesize the first frame instead.
|
||||
- `fg`: if specified, use the foreground-background separation model as stated in the paper. The foreground labels must be specified by `--fg_labels`.
|
||||
- `no_flow`: if specified, do not use flow warping and directly synthesize frames. We found this usually still works reasonably well when the background is static, while saving memory and training time.
|
||||
- For other flags, please see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
|
||||
|
||||
- Additional flags for edge2face examples:
|
||||
- `no_canny_edge`: do not use canny edges for background as input.
|
||||
- `no_dist_map`: by default, we use distrance transform on the face edge map as input. This flag will make it directly use edge maps.
|
||||
|
||||
- Additional flags for pose2body examples:
|
||||
- `densepose_only`: use only densepose results as input. Please also remember to change `input_nc` to be 3.
|
||||
- `openpose_only`: use only openpose results as input. Please also remember to change `input_nc` to be 3.
|
||||
- `add_face_disc`: add an additional discriminator that only works on the face region.
|
||||
- `remove_face_labels`: remove densepose results for face, and add noise to openpose face results, so the network can get more robust to different face shapes. This is important if you plan to do inference on half-body videos (if not, usually this flag is unnecessary).
|
||||
- `random_drop_prob`: the probability to randomly drop each pose segment during training, so the network can get more robust to missing poses at inference time. Default is 0.2.
|
||||
|
||||
## Citation
|
||||
|
||||
If you find this useful for your research, please cite the following paper.
|
||||
|
||||
```
|
||||
@article{wang2018vid2vid,
|
||||
title={Video-to-Video Synthesis},
|
||||
author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Guilin Liu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
|
||||
journal={arXiv preprint arXiv:1808.06601},
|
||||
year={2018}
|
||||
@inproceedings{wang2018vid2vid,
|
||||
author = {Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Guilin Liu
|
||||
and Andrew Tao and Jan Kautz and Bryan Catanzaro},
|
||||
title = {Video-to-Video Synthesis},
|
||||
booktitle = {Advances in Neural Information Processing Systems (NIPS)},
|
||||
year = {2018},
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch.utils.data as data
|
||||
import torch
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
@ -14,45 +15,104 @@ class BaseDataset(data.Dataset):
|
||||
def initialize(self, opt):
|
||||
pass
|
||||
|
||||
def get_params(opt, size):
|
||||
w, h = size
|
||||
new_h = h
|
||||
new_w = w
|
||||
if 'resize' in opt.resize_or_crop:
|
||||
new_h = new_w = opt.loadSize
|
||||
elif 'scaleWidth' in opt.resize_or_crop:
|
||||
new_w = opt.loadSize
|
||||
new_h = opt.loadSize * h / w
|
||||
def update_training_batch(self, ratio): # update the training sequence length to be longer
|
||||
seq_len_max = min(128, self.seq_len_max) - (self.opt.n_frames_G - 1)
|
||||
if self.n_frames_total < seq_len_max:
|
||||
self.n_frames_total = min(seq_len_max, self.opt.n_frames_total * (2**ratio))
|
||||
#self.n_frames_total = min(seq_len_max, self.opt.n_frames_total * (ratio + 1))
|
||||
print('--------- Updating training sequence length to %d ---------' % self.n_frames_total)
|
||||
|
||||
if 'crop' in opt.resize_or_crop:
|
||||
x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
|
||||
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
|
||||
elif 'scaledCrop' in opt.resize_or_crop:
|
||||
x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
|
||||
y = random.randint(0, np.maximum(0, new_h - opt.fineSize*new_h//new_w))
|
||||
def init_frame_idx(self, A_paths):
|
||||
self.n_of_seqs = len(A_paths) # number of sequences to train
|
||||
self.seq_len_max = max([len(A) for A in A_paths]) # max number of frames in the training sequences
|
||||
|
||||
self.seq_idx = 0 # index for current sequence
|
||||
self.frame_idx = self.opt.start_frame if not self.opt.isTrain else 0 # index for current frame in the sequence
|
||||
self.frames_count = [] # number of frames in each sequence
|
||||
for path in A_paths:
|
||||
self.frames_count.append(len(path) - self.opt.n_frames_G + 1)
|
||||
|
||||
self.folder_prob = [count / sum(self.frames_count) for count in self.frames_count]
|
||||
self.n_frames_total = self.opt.n_frames_total if self.opt.isTrain else 1
|
||||
self.A, self.B, self.I = None, None, None
|
||||
|
||||
def update_frame_idx(self, A_paths, index):
|
||||
if self.opt.isTrain:
|
||||
if self.opt.dataset_mode == 'pose':
|
||||
seq_idx = np.random.choice(len(A_paths), p=self.folder_prob) # randomly pick sequence to train
|
||||
else:
|
||||
seq_idx = index % self.n_of_seqs
|
||||
return None, None, None, seq_idx
|
||||
else:
|
||||
self.change_seq = self.frame_idx >= self.frames_count[self.seq_idx]
|
||||
if self.change_seq:
|
||||
self.seq_idx += 1
|
||||
self.frame_idx = 0
|
||||
self.A, self.B, self.I = None, None, None
|
||||
return self.A, self.B, self.I, self.seq_idx
|
||||
|
||||
def make_power_2(n, base=32.0):
|
||||
return int(round(n / base) * base)
|
||||
|
||||
def get_img_params(opt, size):
|
||||
w, h = size
|
||||
new_h, new_w = h, w
|
||||
if 'resize' in opt.resize_or_crop: # resize image to be loadSize x loadSize
|
||||
new_h = new_w = opt.loadSize
|
||||
elif 'scaleWidth' in opt.resize_or_crop: # scale image width to be loadSize
|
||||
new_w = opt.loadSize
|
||||
new_h = opt.loadSize * h // w
|
||||
elif 'scaleHeight' in opt.resize_or_crop: # scale image height to be loadSize
|
||||
new_h = opt.loadSize
|
||||
new_w = opt.loadSize * w // h
|
||||
elif 'randomScaleWidth' in opt.resize_or_crop: # randomly scale image width to be somewhere between loadSize and fineSize
|
||||
new_w = random.randint(opt.fineSize, opt.loadSize + 1)
|
||||
new_h = new_w * h // w
|
||||
elif 'randomScaleHeight' in opt.resize_or_crop: # randomly scale image height to be somewhere between loadSize and fineSize
|
||||
new_h = random.randint(opt.fineSize, opt.loadSize + 1)
|
||||
new_w = new_h * w // h
|
||||
new_w = int(round(new_w / 4)) * 4
|
||||
new_h = int(round(new_h / 4)) * 4
|
||||
|
||||
crop_x = crop_y = 0
|
||||
crop_w = crop_h = 0
|
||||
if 'crop' in opt.resize_or_crop or 'scaledCrop' in opt.resize_or_crop:
|
||||
if 'crop' in opt.resize_or_crop: # crop patches of size fineSize x fineSize
|
||||
crop_w = crop_h = opt.fineSize
|
||||
else:
|
||||
if 'Width' in opt.resize_or_crop: # crop patches of width fineSize
|
||||
crop_w = opt.fineSize
|
||||
crop_h = opt.fineSize * h // w
|
||||
else: # crop patches of height fineSize
|
||||
crop_h = opt.fineSize
|
||||
crop_w = opt.fineSize * w // h
|
||||
|
||||
crop_w, crop_h = make_power_2(crop_w), make_power_2(crop_h)
|
||||
x_span = (new_w - crop_w) // 2
|
||||
crop_x = np.maximum(0, np.minimum(x_span*2, int(np.random.randn() * x_span/3 + x_span)))
|
||||
crop_y = random.randint(0, np.minimum(np.maximum(0, new_h - crop_h), new_h // 8))
|
||||
#crop_x = random.randint(0, np.maximum(0, new_w - crop_w))
|
||||
#crop_y = random.randint(0, np.maximum(0, new_h - crop_h))
|
||||
else:
|
||||
x = y = 0
|
||||
|
||||
flip = random.random() > 0.5
|
||||
return {'crop_pos': (x,y), 'flip': flip}
|
||||
new_w, new_h = make_power_2(new_w), make_power_2(new_h)
|
||||
|
||||
flip = random.random() > 0.5
|
||||
return {'new_size': (new_w, new_h), 'crop_size': (crop_w, crop_h), 'crop_pos': (crop_x, crop_y), 'flip': flip}
|
||||
|
||||
def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
|
||||
transform_list = []
|
||||
### resize input image
|
||||
if 'resize' in opt.resize_or_crop:
|
||||
osize = [opt.loadSize, opt.loadSize]
|
||||
transform_list.append(transforms.Scale(osize, method))
|
||||
elif 'scaleWidth' in opt.resize_or_crop:
|
||||
transform_list.append(transforms.Lambda(lambda img: __scale_image(img, opt.loadSize, method)))
|
||||
else:
|
||||
transform_list.append(transforms.Lambda(lambda img: __scale_image(img, params['new_size'], method)))
|
||||
|
||||
if 'crop' in opt.resize_or_crop:
|
||||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
|
||||
elif 'scaledCrop' in opt.resize_or_crop:
|
||||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize, False)))
|
||||
|
||||
elif opt.resize_or_crop == 'none':
|
||||
base = 32
|
||||
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
|
||||
### crop patches from image
|
||||
if 'crop' in opt.resize_or_crop or 'scaledCrop' in opt.resize_or_crop:
|
||||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_size'], params['crop_pos'])))
|
||||
|
||||
### random flip
|
||||
if opt.isTrain and not opt.no_flip:
|
||||
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
|
||||
|
||||
@ -69,51 +129,56 @@ def toTensor_normalize():
|
||||
(0.5, 0.5, 0.5))]
|
||||
return transforms.Compose(transform_list)
|
||||
|
||||
def __scale_image(img, target_width, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
if ow > oh:
|
||||
w = target_width
|
||||
h = int(target_width * oh / ow)
|
||||
else:
|
||||
h = target_width
|
||||
w = int(target_width * ow / oh)
|
||||
base = 32.0
|
||||
h = int(round(h / base) * base)
|
||||
w = int(round(w / base) * base)
|
||||
def __scale_image(img, size, method=Image.BICUBIC):
|
||||
w, h = size
|
||||
return img.resize((w, h), method)
|
||||
|
||||
def __make_power_2(img, base, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
h = int(round(oh / base) * base)
|
||||
w = int(round(ow / base) * base)
|
||||
if (h == oh) and (w == ow):
|
||||
return img
|
||||
return img.resize((w, h), method)
|
||||
|
||||
def __scale_width(img, target_width, method=Image.BICUBIC):
|
||||
def __crop(img, size, pos):
|
||||
ow, oh = img.size
|
||||
if (ow == target_width):
|
||||
return img
|
||||
w = target_width
|
||||
h = int(target_width * oh / ow)
|
||||
base = 32.0
|
||||
h = int(round(h / base) * base)
|
||||
w = int(round(w / base) * base)
|
||||
return img.resize((w, h), method)
|
||||
|
||||
|
||||
def __crop(img, pos, size, square=True):
|
||||
ow, oh = img.size
|
||||
x1, y1 = pos
|
||||
tw = th = size
|
||||
if not square:
|
||||
th = th * oh // ow
|
||||
tw, th = size
|
||||
x1, y1 = pos
|
||||
if (ow > tw or oh > th):
|
||||
return img.crop((x1, y1, min(ow, x1 + tw), min(oh, y1 + th)))
|
||||
return img
|
||||
|
||||
|
||||
def __flip(img, flip):
|
||||
if flip:
|
||||
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
return img
|
||||
|
||||
def get_video_params(opt, n_frames_total, cur_seq_len, index):
|
||||
tG = opt.n_frames_G
|
||||
if opt.isTrain:
|
||||
n_frames_total = min(n_frames_total, cur_seq_len - tG + 1)
|
||||
|
||||
n_gpus = opt.n_gpus_gen // opt.batchSize # number of generator GPUs for each batch
|
||||
n_frames_per_load = opt.max_frames_per_gpu * n_gpus # number of frames to load into GPUs at one time (for each batch)
|
||||
n_frames_per_load = min(n_frames_total, n_frames_per_load)
|
||||
n_loadings = n_frames_total // n_frames_per_load # how many times are needed to load entire sequence into GPUs
|
||||
n_frames_total = n_frames_per_load * n_loadings + tG - 1 # rounded overall number of frames to read from the sequence
|
||||
|
||||
max_t_step = min(opt.max_t_step, (cur_seq_len-1) // (n_frames_total-1))
|
||||
t_step = np.random.randint(max_t_step) + 1 # spacing between neighboring sampled frames
|
||||
offset_max = max(1, cur_seq_len - (n_frames_total-1)*t_step) # maximum possible index for the first frame
|
||||
if opt.dataset_mode == 'pose':
|
||||
start_idx = index# % offset_max
|
||||
else:
|
||||
start_idx = np.random.randint(offset_max) # offset for the first frame to load
|
||||
if opt.debug:
|
||||
print("loading %d frames in total, first frame starting at index %d, space between neighboring frames is %d"
|
||||
% (n_frames_total, start_idx, t_step))
|
||||
else:
|
||||
n_frames_total = tG
|
||||
start_idx = index
|
||||
t_step = 1
|
||||
return n_frames_total, start_idx, t_step
|
||||
|
||||
def concat_frame(A, Ai, nF):
|
||||
if A is None:
|
||||
A = Ai
|
||||
else:
|
||||
c = Ai.size()[0]
|
||||
if A.size()[0] == nF * c:
|
||||
A = A[c:]
|
||||
A = torch.cat([A, Ai])
|
||||
return A
|
||||
@ -10,6 +10,9 @@ def CreateDataset(opt):
|
||||
elif opt.dataset_mode == 'face':
|
||||
from data.face_dataset import FaceDataset
|
||||
dataset = FaceDataset()
|
||||
elif opt.dataset_mode == 'pose':
|
||||
from data.pose_dataset import PoseDataset
|
||||
dataset = PoseDataset()
|
||||
elif opt.dataset_mode == 'test':
|
||||
from data.test_dataset import TestDataset
|
||||
dataset = TestDataset()
|
||||
|
||||
172
data/face_dataset.py
Executable file
172
data/face_dataset.py
Executable file
@ -0,0 +1,172 @@
|
||||
import os.path
|
||||
import torchvision.transforms as transforms
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import cv2
|
||||
from skimage import feature
|
||||
|
||||
from data.base_dataset import BaseDataset, get_img_params, get_transform, get_video_params, concat_frame
|
||||
from data.image_folder import make_grouped_dataset, check_path_valid
|
||||
from data.keypoint2img import interpPoints, drawEdge
|
||||
|
||||
class FaceDataset(BaseDataset):
|
||||
def initialize(self, opt):
|
||||
self.opt = opt
|
||||
self.root = opt.dataroot
|
||||
self.dir_A = os.path.join(opt.dataroot, opt.phase + '_keypoints')
|
||||
self.dir_B = os.path.join(opt.dataroot, opt.phase + '_img')
|
||||
|
||||
self.A_paths = sorted(make_grouped_dataset(self.dir_A))
|
||||
self.B_paths = sorted(make_grouped_dataset(self.dir_B))
|
||||
check_path_valid(self.A_paths, self.B_paths)
|
||||
|
||||
self.init_frame_idx(self.A_paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
A, B, I, seq_idx = self.update_frame_idx(self.A_paths, index)
|
||||
A_paths = self.A_paths[seq_idx]
|
||||
B_paths = self.B_paths[seq_idx]
|
||||
n_frames_total, start_idx, t_step = get_video_params(self.opt, self.n_frames_total, len(A_paths), self.frame_idx)
|
||||
|
||||
B_img = Image.open(B_paths[0]).convert('RGB')
|
||||
B_size = B_img.size
|
||||
points = np.loadtxt(A_paths[0], delimiter=',')
|
||||
is_first_frame = self.opt.isTrain or not hasattr(self, 'min_x')
|
||||
if is_first_frame: # crop only the face region
|
||||
self.get_crop_coords(points, B_size)
|
||||
params = get_img_params(self.opt, self.crop(B_img).size)
|
||||
transform_scaleA = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False)
|
||||
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
|
||||
transform_scaleB = get_transform(self.opt, params)
|
||||
|
||||
# read in images
|
||||
frame_range = list(range(n_frames_total)) if self.A is None else [self.opt.n_frames_G-1]
|
||||
for i in frame_range:
|
||||
A_path = A_paths[start_idx + i * t_step]
|
||||
B_path = B_paths[start_idx + i * t_step]
|
||||
B_img = Image.open(B_path)
|
||||
Ai, Li = self.get_face_image(A_path, transform_scaleA, transform_label, B_size, B_img)
|
||||
Bi = transform_scaleB(self.crop(B_img))
|
||||
A = concat_frame(A, Ai, n_frames_total)
|
||||
B = concat_frame(B, Bi, n_frames_total)
|
||||
I = concat_frame(I, Li, n_frames_total)
|
||||
|
||||
if not self.opt.isTrain:
|
||||
self.A, self.B, self.I = A, B, I
|
||||
self.frame_idx += 1
|
||||
change_seq = False if self.opt.isTrain else self.change_seq
|
||||
return_list = {'A': A, 'B': B, 'inst': I, 'A_path': A_path, 'change_seq': change_seq}
|
||||
|
||||
return return_list
|
||||
|
||||
def get_image(self, A_path, transform_scaleA):
|
||||
A_img = Image.open(A_path)
|
||||
A_scaled = transform_scaleA(self.crop(A_img))
|
||||
return A_scaled
|
||||
|
||||
def get_face_image(self, A_path, transform_A, transform_L, size, img):
|
||||
# read face keypoints from path and crop face region
|
||||
keypoints, part_list, part_labels = self.read_keypoints(A_path, size)
|
||||
|
||||
# draw edges and possibly add distance transform maps
|
||||
add_dist_map = not self.opt.no_dist_map
|
||||
im_edges, dist_tensor = self.draw_face_edges(keypoints, part_list, transform_A, size, add_dist_map)
|
||||
|
||||
# canny edge for background
|
||||
if not self.opt.no_canny_edge:
|
||||
edges = feature.canny(np.array(img.convert('L')))
|
||||
edges = edges * (part_labels == 0) # remove edges within face
|
||||
im_edges += (edges * 255).astype(np.uint8)
|
||||
edge_tensor = transform_A(Image.fromarray(self.crop(im_edges)))
|
||||
|
||||
# final input tensor
|
||||
input_tensor = torch.cat([edge_tensor, dist_tensor]) if add_dist_map else edge_tensor
|
||||
label_tensor = transform_L(Image.fromarray(self.crop(part_labels.astype(np.uint8)))) * 255.0
|
||||
return input_tensor, label_tensor
|
||||
|
||||
def read_keypoints(self, A_path, size):
|
||||
# mapping from keypoints to face part
|
||||
part_list = [[list(range(0, 17)) + list(range(68, 83)) + [0]], # face
|
||||
[range(17, 22)], # right eyebrow
|
||||
[range(22, 27)], # left eyebrow
|
||||
[[28, 31], range(31, 36), [35, 28]], # nose
|
||||
[[36,37,38,39], [39,40,41,36]], # right eye
|
||||
[[42,43,44,45], [45,46,47,42]], # left eye
|
||||
[range(48, 55), [54,55,56,57,58,59,48]], # mouth
|
||||
[range(60, 65), [64,65,66,67,60]] # tongue
|
||||
]
|
||||
label_list = [1, 2, 2, 3, 4, 4, 5, 6] # labeling for different facial parts
|
||||
keypoints = np.loadtxt(A_path, delimiter=',')
|
||||
|
||||
# add upper half face by symmetry
|
||||
pts = keypoints[:17, :].astype(np.int32)
|
||||
baseline_y = (pts[0,1] + pts[-1,1]) / 2
|
||||
upper_pts = pts[1:-1,:].copy()
|
||||
upper_pts[:,1] = baseline_y + (baseline_y-upper_pts[:,1]) * 2 // 3
|
||||
keypoints = np.vstack((keypoints, upper_pts[::-1,:]))
|
||||
|
||||
# label map for facial part
|
||||
w, h = size
|
||||
part_labels = np.zeros((h, w), np.uint8)
|
||||
for p, edge_list in enumerate(part_list):
|
||||
indices = [item for sublist in edge_list for item in sublist]
|
||||
pts = keypoints[indices, :].astype(np.int32)
|
||||
cv2.fillPoly(part_labels, pts=[pts], color=label_list[p])
|
||||
|
||||
return keypoints, part_list, part_labels
|
||||
|
||||
def draw_face_edges(self, keypoints, part_list, transform_A, size, add_dist_map):
|
||||
w, h = size
|
||||
edge_len = 3 # interpolate 3 keypoints to form a curve when drawing edges
|
||||
# edge map for face region from keypoints
|
||||
im_edges = np.zeros((h, w), np.uint8) # edge map for all edges
|
||||
dist_tensor = 0
|
||||
e = 1
|
||||
for edge_list in part_list:
|
||||
for edge in edge_list:
|
||||
im_edge = np.zeros((h, w), np.uint8) # edge map for the current edge
|
||||
for i in range(0, max(1, len(edge)-1), edge_len-1): # divide a long edge into multiple small edges when drawing
|
||||
sub_edge = edge[i:i+edge_len]
|
||||
x = keypoints[sub_edge, 0]
|
||||
y = keypoints[sub_edge, 1]
|
||||
|
||||
curve_x, curve_y = interpPoints(x, y) # interp keypoints to get the curve shape
|
||||
drawEdge(im_edges, curve_x, curve_y)
|
||||
if add_dist_map:
|
||||
drawEdge(im_edge, curve_x, curve_y)
|
||||
|
||||
if add_dist_map: # add distance transform map on each facial part
|
||||
im_dist = cv2.distanceTransform(255-im_edge, cv2.DIST_L1, 3)
|
||||
im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
|
||||
im_dist = Image.fromarray(im_dist)
|
||||
tensor_cropped = transform_A(self.crop(im_dist))
|
||||
dist_tensor = tensor_cropped if e == 1 else torch.cat([dist_tensor, tensor_cropped])
|
||||
e += 1
|
||||
|
||||
return im_edges, dist_tensor
|
||||
|
||||
def get_crop_coords(self, keypoints, size):
|
||||
min_y, max_y = keypoints[:,1].min(), keypoints[:,1].max()
|
||||
min_x, max_x = keypoints[:,0].min(), keypoints[:,0].max()
|
||||
offset = (max_x - min_x) // 2
|
||||
min_y = max(0, min_y - offset*2)
|
||||
min_x = max(0, min_x - offset)
|
||||
max_x = min(size[0], max_x + offset)
|
||||
max_y = min(size[1], max_y + offset)
|
||||
self.min_y, self.max_y, self.min_x, self.max_x = int(min_y), int(max_y), int(min_x), int(max_x)
|
||||
|
||||
def crop(self, img):
|
||||
if isinstance(img, np.ndarray):
|
||||
return img[self.min_y:self.max_y, self.min_x:self.max_x]
|
||||
else:
|
||||
return img.crop((self.min_x, self.min_y, self.max_x, self.max_y))
|
||||
|
||||
def __len__(self):
|
||||
if self.opt.isTrain:
|
||||
return len(self.A_paths)
|
||||
else:
|
||||
return sum(self.frames_count)
|
||||
|
||||
def name(self):
|
||||
return 'FaceDataset'
|
||||
37
data/face_landmark_detection.py
Executable file
37
data/face_landmark_detection.py
Executable file
@ -0,0 +1,37 @@
|
||||
import os
|
||||
import glob
|
||||
from skimage import io
|
||||
import numpy as np
|
||||
import dlib
|
||||
import sys
|
||||
|
||||
if len(sys.argv) < 2 or (sys.argv[1] != 'train' and sys.argv[1] != 'test'):
|
||||
raise ValueError('usage: python data/face_landmark_detection.py [train|test]')
|
||||
|
||||
phase = sys.argv[1]
|
||||
dataset_path = 'datasets/face/'
|
||||
faces_folder_path = os.path.join(dataset_path, phase + '_img/')
|
||||
predictor_path = os.path.join(dataset_path, 'shape_predictor_68_face_landmarks.dat')
|
||||
detector = dlib.get_frontal_face_detector()
|
||||
predictor = dlib.shape_predictor(predictor_path)
|
||||
|
||||
img_paths = sorted(glob.glob(faces_folder_path + '*'))
|
||||
for i in range(len(img_paths)):
|
||||
f = img_paths[i]
|
||||
print("Processing video: {}".format(f))
|
||||
save_path = os.path.join(dataset_path, phase + '_keypoints', os.path.basename(f))
|
||||
if not os.path.isdir(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
for img_name in sorted(glob.glob(os.path.join(f, '*.jpg'))):
|
||||
img = io.imread(img_name)
|
||||
dets = detector(img, 1)
|
||||
if len(dets) > 0:
|
||||
shape = predictor(img, dets[0])
|
||||
points = np.empty([68, 2], dtype=int)
|
||||
for b in range(68):
|
||||
points[b,0] = shape.part(b).x
|
||||
points[b,1] = shape.part(b).y
|
||||
|
||||
save_name = os.path.join(save_path, os.path.basename(img_name)[:-4] + '.txt')
|
||||
np.savetxt(save_name, points, fmt='%d', delimiter=',')
|
||||
@ -13,7 +13,8 @@ import os.path
|
||||
|
||||
IMG_EXTENSIONS = [
|
||||
'.jpg', '.JPG', '.jpeg', '.JPEG', '.pgm', '.PGM',
|
||||
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.txt'
|
||||
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff',
|
||||
'.txt', '.json'
|
||||
]
|
||||
|
||||
|
||||
@ -46,6 +47,11 @@ def make_grouped_dataset(dir):
|
||||
images.append(paths)
|
||||
return images
|
||||
|
||||
def check_path_valid(A_paths, B_paths):
|
||||
assert(len(A_paths) == len(B_paths))
|
||||
for a, b in zip(A_paths, B_paths):
|
||||
assert(len(a) == len(b))
|
||||
|
||||
def default_loader(path):
|
||||
return Image.open(path).convert('RGB')
|
||||
|
||||
|
||||
191
data/keypoint2img.py
Executable file
191
data/keypoint2img.py
Executable file
@ -0,0 +1,191 @@
|
||||
import os.path
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import json
|
||||
import glob
|
||||
from scipy.optimize import curve_fit
|
||||
import warnings
|
||||
|
||||
def func(x, a, b, c):
|
||||
return a * x**2 + b * x + c
|
||||
|
||||
def linear(x, a, b):
|
||||
return a * x + b
|
||||
|
||||
def setColor(im, yy, xx, color):
|
||||
if len(im.shape) == 3:
|
||||
if (im[yy, xx] == 0).all():
|
||||
im[yy, xx, 0], im[yy, xx, 1], im[yy, xx, 2] = color[0], color[1], color[2]
|
||||
else:
|
||||
im[yy, xx, 0] = ((im[yy, xx, 0].astype(float) + color[0]) / 2).astype(np.uint8)
|
||||
im[yy, xx, 1] = ((im[yy, xx, 1].astype(float) + color[1]) / 2).astype(np.uint8)
|
||||
im[yy, xx, 2] = ((im[yy, xx, 2].astype(float) + color[2]) / 2).astype(np.uint8)
|
||||
else:
|
||||
im[yy, xx] = color[0]
|
||||
|
||||
def drawEdge(im, x, y, bw=1, color=(255,255,255), draw_end_points=False):
|
||||
if x is not None and x.size:
|
||||
h, w = im.shape[0], im.shape[1]
|
||||
# edge
|
||||
for i in range(-bw, bw):
|
||||
for j in range(-bw, bw):
|
||||
yy = np.maximum(0, np.minimum(h-1, y+i))
|
||||
xx = np.maximum(0, np.minimum(w-1, x+j))
|
||||
setColor(im, yy, xx, color)
|
||||
|
||||
# edge endpoints
|
||||
if draw_end_points:
|
||||
for i in range(-bw*2, bw*2):
|
||||
for j in range(-bw*2, bw*2):
|
||||
if (i**2) + (j**2) < (4 * bw**2):
|
||||
yy = np.maximum(0, np.minimum(h-1, np.array([y[0], y[-1]])+i))
|
||||
xx = np.maximum(0, np.minimum(w-1, np.array([x[0], x[-1]])+j))
|
||||
setColor(im, yy, xx, color)
|
||||
|
||||
def interpPoints(x, y):
|
||||
if abs(x[:-1] - x[1:]).max() < abs(y[:-1] - y[1:]).max():
|
||||
curve_y, curve_x = interpPoints(y, x)
|
||||
if curve_y is None:
|
||||
return None, None
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
if len(x) < 3:
|
||||
popt, _ = curve_fit(linear, x, y)
|
||||
else:
|
||||
popt, _ = curve_fit(func, x, y)
|
||||
if abs(popt[0]) > 1:
|
||||
return None, None
|
||||
if x[0] > x[-1]:
|
||||
x = list(reversed(x))
|
||||
y = list(reversed(y))
|
||||
curve_x = np.linspace(x[0], x[-1], (x[-1]-x[0]))
|
||||
if len(x) < 3:
|
||||
curve_y = linear(curve_x, *popt)
|
||||
else:
|
||||
curve_y = func(curve_x, *popt)
|
||||
return curve_x.astype(int), curve_y.astype(int)
|
||||
|
||||
def read_keypoints(json_input, size, random_drop_prob=0, remove_face_labels=False):
|
||||
with open(json_input, encoding='utf-8') as f:
|
||||
keypoint_dicts = json.loads(f.read())["people"]
|
||||
|
||||
edge_lists = define_edge_lists()
|
||||
w, h = size
|
||||
pose_img = np.zeros((h, w, 3), np.uint8)
|
||||
for keypoint_dict in keypoint_dicts:
|
||||
pose_pts = np.array(keypoint_dict["pose_keypoints_2d"]).reshape(25, 3)
|
||||
face_pts = np.array(keypoint_dict["face_keypoints_2d"]).reshape(70, 3)
|
||||
hand_pts_l = np.array(keypoint_dict["hand_left_keypoints_2d"]).reshape(21, 3)
|
||||
hand_pts_r = np.array(keypoint_dict["hand_right_keypoints_2d"]).reshape(21, 3)
|
||||
pts = [extract_valid_keypoints(pts, edge_lists) for pts in [pose_pts, face_pts, hand_pts_l, hand_pts_r]]
|
||||
pose_img += connect_keypoints(pts, edge_lists, size, random_drop_prob, remove_face_labels)
|
||||
return pose_img
|
||||
|
||||
def extract_valid_keypoints(pts, edge_lists):
|
||||
pose_edge_list, _, hand_edge_list, _, face_list = edge_lists
|
||||
p = pts.shape[0]
|
||||
thre = 0.1 if p == 70 else 0.01
|
||||
output = np.zeros((p, 2))
|
||||
|
||||
if p == 70: # face
|
||||
for edge_list in face_list:
|
||||
for edge in edge_list:
|
||||
if (pts[edge, 2] > thre).all():
|
||||
output[edge, :] = pts[edge, :2]
|
||||
elif p == 21: # hand
|
||||
for edge in hand_edge_list:
|
||||
if (pts[edge, 2] > thre).all():
|
||||
output[edge, :] = pts[edge, :2]
|
||||
else: # pose
|
||||
valid = (pts[:, 2] > thre)
|
||||
output[valid, :] = pts[valid, :2]
|
||||
|
||||
return output
|
||||
|
||||
def connect_keypoints(pts, edge_lists, size, random_drop_prob, remove_face_labels):
|
||||
pose_pts, face_pts, hand_pts_l, hand_pts_r = pts
|
||||
w, h = size
|
||||
output_edges = np.zeros((h, w, 3), np.uint8)
|
||||
pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, face_list = edge_lists
|
||||
|
||||
if random_drop_prob > 0 and remove_face_labels:
|
||||
# add random noise to keypoints
|
||||
pose_pts[[0,15,16,17,18], :] += 5 * np.random.randn(5,2)
|
||||
face_pts[:,0] += 2 * np.random.randn()
|
||||
face_pts[:,1] += 2 * np.random.randn()
|
||||
|
||||
### pose
|
||||
for i, edge in enumerate(pose_edge_list):
|
||||
x, y = pose_pts[edge, 0], pose_pts[edge, 1]
|
||||
if (np.random.rand() > random_drop_prob) and (0 not in x):
|
||||
curve_x, curve_y = interpPoints(x, y)
|
||||
drawEdge(output_edges, curve_x, curve_y, bw=3, color=pose_color_list[i], draw_end_points=True)
|
||||
|
||||
### hand
|
||||
for hand_pts in [hand_pts_l, hand_pts_r]: # for left and right hand
|
||||
if np.random.rand() > random_drop_prob:
|
||||
for i, edge in enumerate(hand_edge_list): # for each finger
|
||||
for j in range(0, len(edge)-1): # for each part of the finger
|
||||
sub_edge = edge[j:j+2]
|
||||
x, y = hand_pts[sub_edge, 0], hand_pts[sub_edge, 1]
|
||||
if 0 not in x:
|
||||
line_x, line_y = interpPoints(x, y)
|
||||
drawEdge(output_edges, line_x, line_y, bw=1, color=hand_color_list[i], draw_end_points=True)
|
||||
|
||||
### face
|
||||
edge_len = 2
|
||||
if (np.random.rand() > random_drop_prob):
|
||||
for edge_list in face_list:
|
||||
for edge in edge_list:
|
||||
for i in range(0, max(1, len(edge)-1), edge_len-1):
|
||||
sub_edge = edge[i:i+edge_len]
|
||||
x, y = face_pts[sub_edge, 0], face_pts[sub_edge, 1]
|
||||
if 0 not in x:
|
||||
curve_x, curve_y = interpPoints(x, y)
|
||||
drawEdge(output_edges, curve_x, curve_y, draw_end_points=True)
|
||||
|
||||
return output_edges
|
||||
|
||||
def define_edge_lists():
|
||||
### pose
|
||||
pose_edge_list = [
|
||||
[17, 15], [15, 0], [ 0, 16], [16, 18], [ 0, 1], # head
|
||||
[ 1, 8], # body
|
||||
[ 1, 2], [ 2, 3], [ 3, 4], # right arm
|
||||
[ 1, 5], [ 5, 6], [ 6, 7], # left arm
|
||||
[ 8, 9], [ 9, 10], [10, 11], [11, 24], [11, 22], [22, 23], # right leg
|
||||
[ 8, 12], [12, 13], [13, 14], [14, 21], [14, 19], [19, 20] # left leg
|
||||
]
|
||||
pose_color_list = [
|
||||
[153, 0,153], [153, 0,102], [102, 0,153], [ 51, 0,153], [153, 0, 51],
|
||||
[153, 0, 0],
|
||||
[153, 51, 0], [153,102, 0], [153,153, 0],
|
||||
[102,153, 0], [ 51,153, 0], [ 0,153, 0],
|
||||
[ 0,153, 51], [ 0,153,102], [ 0,153,153], [ 0,153,153], [ 0,153,153], [ 0,153,153],
|
||||
[ 0,102,153], [ 0, 51,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153]
|
||||
]
|
||||
|
||||
### hand
|
||||
hand_edge_list = [
|
||||
[0, 1, 2, 3, 4],
|
||||
[0, 5, 6, 7, 8],
|
||||
[0, 9, 10, 11, 12],
|
||||
[0, 13, 14, 15, 16],
|
||||
[0, 17, 18, 19, 20]
|
||||
]
|
||||
hand_color_list = [
|
||||
[204,0,0], [163,204,0], [0,204,82], [0,82,204], [163,0,204]
|
||||
]
|
||||
|
||||
### face
|
||||
face_list = [
|
||||
#[range(0, 17)], # face
|
||||
[range(17, 22)], # left eyebrow
|
||||
[range(22, 27)], # right eyebrow
|
||||
[range(27, 31), range(31, 36)], # nose
|
||||
[[36,37,38,39], [39,40,41,36]], # left eye
|
||||
[[42,43,44,45], [45,46,47,42]], # right eye
|
||||
[range(48, 55), [54,55,56,57,58,59,48]], # mouth
|
||||
]
|
||||
return pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, face_list
|
||||
156
data/pose_dataset.py
Executable file
156
data/pose_dataset.py
Executable file
@ -0,0 +1,156 @@
|
||||
import os.path
|
||||
import torchvision.transforms as transforms
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from data.base_dataset import BaseDataset, get_img_params, get_transform, get_video_params, concat_frame
|
||||
from data.image_folder import make_grouped_dataset, check_path_valid
|
||||
from data.keypoint2img import read_keypoints
|
||||
|
||||
class PoseDataset(BaseDataset):
|
||||
def initialize(self, opt):
|
||||
self.opt = opt
|
||||
self.root = opt.dataroot
|
||||
|
||||
self.dir_dp = os.path.join(opt.dataroot, opt.phase + '_densepose')
|
||||
self.dir_op = os.path.join(opt.dataroot, opt.phase + '_openpose')
|
||||
self.dir_img = os.path.join(opt.dataroot, opt.phase + '_img')
|
||||
self.img_paths = sorted(make_grouped_dataset(self.dir_img))
|
||||
if not opt.openpose_only:
|
||||
self.dp_paths = sorted(make_grouped_dataset(self.dir_dp))
|
||||
check_path_valid(self.dp_paths, self.img_paths)
|
||||
if not opt.densepose_only:
|
||||
self.op_paths = sorted(make_grouped_dataset(self.dir_op))
|
||||
check_path_valid(self.op_paths, self.img_paths)
|
||||
|
||||
self.init_frame_idx(self.img_paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
A, B, _, seq_idx = self.update_frame_idx(self.img_paths, index)
|
||||
img_paths = self.img_paths[seq_idx]
|
||||
n_frames_total, start_idx, t_step = get_video_params(self.opt, self.n_frames_total, len(img_paths), self.frame_idx)
|
||||
|
||||
img = Image.open(img_paths[0]).convert('RGB')
|
||||
size = img.size
|
||||
params = get_img_params(self.opt, size)
|
||||
|
||||
frame_range = list(range(n_frames_total)) if (self.opt.isTrain or self.A is None) else [self.opt.n_frames_G-1]
|
||||
for i in frame_range:
|
||||
img_path = img_paths[start_idx + i * t_step]
|
||||
if not self.opt.openpose_only:
|
||||
dp_path = self.dp_paths[seq_idx][start_idx + i * t_step]
|
||||
Di = self.get_image(dp_path, size, params, input_type='densepose')
|
||||
Di[2,:,:] = Di[2,:,:] * 255 / 24
|
||||
if not self.opt.densepose_only:
|
||||
op_path = self.op_paths[seq_idx][start_idx + i * t_step]
|
||||
Oi = self.get_image(op_path, size, params, input_type='openpose')
|
||||
|
||||
if self.opt.openpose_only:
|
||||
Ai = Oi
|
||||
elif self.opt.densepose_only:
|
||||
Ai = Di
|
||||
else:
|
||||
Ai = torch.cat([Di, Oi])
|
||||
Bi = self.get_image(img_path, size, params, input_type='img')
|
||||
|
||||
Ai, Bi = self.crop(Ai), self.crop(Bi) # only crop the central half region to save time
|
||||
A = concat_frame(A, Ai, n_frames_total)
|
||||
B = concat_frame(B, Bi, n_frames_total)
|
||||
|
||||
if not self.opt.isTrain:
|
||||
self.A, self.B = A, B
|
||||
self.frame_idx += 1
|
||||
change_seq = False if self.opt.isTrain else self.change_seq
|
||||
return_list = {'A': A, 'B': B, 'inst': 0, 'A_path': img_path, 'change_seq': change_seq}
|
||||
|
||||
return return_list
|
||||
|
||||
def get_image(self, A_path, size, params, input_type):
|
||||
if input_type != 'openpose':
|
||||
A_img = Image.open(A_path).convert('RGB')
|
||||
else:
|
||||
random_drop_prob = self.opt.random_drop_prob if self.opt.isTrain else 0
|
||||
A_img = Image.fromarray(read_keypoints(A_path, size, random_drop_prob, self.opt.remove_face_labels))
|
||||
|
||||
if input_type == 'densepose' and self.opt.isTrain:
|
||||
# randomly remove labels
|
||||
A_np = np.array(A_img)
|
||||
part_labels = A_np[:,:,2]
|
||||
for part_id in range(1, 25):
|
||||
if (np.random.rand() < self.opt.random_drop_prob):
|
||||
A_np[(part_labels == part_id), :] = 0
|
||||
if self.opt.remove_face_labels:
|
||||
A_np[(part_labels == 23) | (part_labels == 24), :] = 0
|
||||
A_img = Image.fromarray(A_np)
|
||||
|
||||
is_img = input_type == 'img'
|
||||
method = Image.BICUBIC if is_img else Image.NEAREST
|
||||
transform_scaleA = get_transform(self.opt, params, normalize=is_img, method=method)
|
||||
A_scaled = transform_scaleA(A_img)
|
||||
return A_scaled
|
||||
|
||||
def crop(self, Ai):
|
||||
w = Ai.size()[2]
|
||||
base = 32
|
||||
x_cen = w // 2
|
||||
bs = int(w * 0.25) // base * base
|
||||
return Ai[:,:,(x_cen-bs):(x_cen+bs)]
|
||||
|
||||
def normalize_pose(self, A_img, target_yc, target_len, first=False):
|
||||
w, h = A_img.size
|
||||
A_np = np.array(A_img)
|
||||
|
||||
if first == True:
|
||||
part_labels = A_np[:,:,2]
|
||||
part_coords = np.nonzero((part_labels == 1) | (part_labels == 2))
|
||||
y, x = part_coords[0], part_coords[1]
|
||||
|
||||
ys, ye = y.min(), y.max()
|
||||
min_i, max_i = np.argmin(y), np.argmax(y)
|
||||
v_min = A_np[y[min_i], x[min_i], 1] / 255
|
||||
v_max = A_np[y[max_i], x[max_i], 1] / 255
|
||||
ylen = (ye-ys) / (v_max-v_min)
|
||||
yc = (0.5-v_min) / (v_max-v_min) * (ye-ys) + ys
|
||||
|
||||
ratio = target_len / ylen
|
||||
offset_y = int(yc - (target_yc / ratio))
|
||||
offset_x = int(w * (1 - 1/ratio) / 2)
|
||||
|
||||
padding = int(max(0, max(-offset_y, int(offset_y + h/ratio) - h)))
|
||||
padding = int(max(padding, max(-offset_x, int(offset_x + w/ratio) - w)))
|
||||
offset_y += padding
|
||||
offset_x += padding
|
||||
self.offset_y, self.offset_x = offset_y, offset_x
|
||||
self.ratio, self.padding = ratio, padding
|
||||
|
||||
p = self.padding
|
||||
A_np = np.pad(A_np, ((p,p),(p,p),(0,0)), 'constant', constant_values=0)
|
||||
A_np = A_np[self.offset_y:int(self.offset_y + h/self.ratio), self.offset_x:int(self.offset_x + w/self.ratio):, :]
|
||||
A_img = Image.fromarray(A_np)
|
||||
A_img = A_img.resize((w, h))
|
||||
return A_img
|
||||
|
||||
def __len__(self):
|
||||
return sum(self.frames_count)
|
||||
|
||||
def name(self):
|
||||
return 'PoseDataset'
|
||||
|
||||
"""
|
||||
DensePose label
|
||||
0 = Background
|
||||
1, 2 = Torso
|
||||
3 = Right Hand
|
||||
4 = Left Hand
|
||||
5 = Right Foot
|
||||
6 = Left Foot
|
||||
7, 9 = Upper Leg Right
|
||||
8, 10 = Upper Leg Left
|
||||
11, 13 = Lower Leg Right
|
||||
12, 14 = Lower Leg Left
|
||||
15, 17 = Upper Arm Left
|
||||
16, 18 = Upper Arm Right
|
||||
19, 21 = Lower Arm Left
|
||||
20, 22 = Lower Arm Right
|
||||
23, 24 = Head """
|
||||
@ -3,8 +3,8 @@
|
||||
import os.path
|
||||
import random
|
||||
import torch
|
||||
from data.base_dataset import BaseDataset, get_params, get_transform
|
||||
from data.image_folder import make_grouped_dataset
|
||||
from data.base_dataset import BaseDataset, get_img_params, get_transform, get_video_params
|
||||
from data.image_folder import make_grouped_dataset, check_path_valid
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
@ -18,48 +18,29 @@ class TemporalDataset(BaseDataset):
|
||||
|
||||
self.A_paths = sorted(make_grouped_dataset(self.dir_A))
|
||||
self.B_paths = sorted(make_grouped_dataset(self.dir_B))
|
||||
assert(len(self.A_paths) == len(self.B_paths))
|
||||
check_path_valid(self.A_paths, self.B_paths)
|
||||
if opt.use_instance:
|
||||
self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
|
||||
self.I_paths = sorted(make_grouped_dataset(self.dir_inst))
|
||||
assert(len(self.A_paths) == len(self.I_paths))
|
||||
check_path_valid(self.A_paths, self.I_paths)
|
||||
|
||||
self.n_of_seqs = len(self.A_paths) # number of sequences to train
|
||||
self.seq_len_max = len(self.A_paths[0]) # max number of frames in the training sequences
|
||||
for i in range(1, self.n_of_seqs):
|
||||
self.seq_len_max = max(self.seq_len_max, len(self.A_paths[i]))
|
||||
self.seq_len_max = max([len(A) for A in self.A_paths])
|
||||
self.n_frames_total = self.opt.n_frames_total # current number of frames to train in a single iteration
|
||||
|
||||
def __getitem__(self, index):
|
||||
tG = self.opt.n_frames_G
|
||||
tG = self.opt.n_frames_G
|
||||
A_paths = self.A_paths[index % self.n_of_seqs]
|
||||
B_paths = self.B_paths[index % self.n_of_seqs]
|
||||
assert(len(A_paths) == len(B_paths))
|
||||
B_paths = self.B_paths[index % self.n_of_seqs]
|
||||
if self.opt.use_instance:
|
||||
I_paths = self.I_paths[index % self.n_of_seqs]
|
||||
assert(len(A_paths) == len(I_paths))
|
||||
I_paths = self.I_paths[index % self.n_of_seqs]
|
||||
|
||||
# setting parameters
|
||||
cur_seq_len = len(A_paths)
|
||||
n_frames_total = min(self.n_frames_total, cur_seq_len - tG + 1)
|
||||
|
||||
n_gpus = self.opt.n_gpus_gen // self.opt.batchSize # number of generator GPUs for each batch
|
||||
n_frames_per_load = self.opt.max_frames_per_gpu * n_gpus # number of frames to load into GPUs at one time (for each batch)
|
||||
n_frames_per_load = min(n_frames_total, n_frames_per_load)
|
||||
n_loadings = n_frames_total // n_frames_per_load # how many times are needed to load entire sequence into GPUs
|
||||
n_frames_total = n_frames_per_load * n_loadings + tG - 1 # rounded overall number of frames to read from the sequence
|
||||
|
||||
#t_step_max = min(1, (cur_seq_len-1) // (n_frames_total-1))
|
||||
#t_step = np.random.randint(t_step_max) + 1 # spacing between neighboring sampled frames
|
||||
t_step = 1
|
||||
offset_max = max(1, cur_seq_len - (n_frames_total-1)*t_step) # maximum possible index for the first frame
|
||||
start_idx = np.random.randint(offset_max) # offset for the first frame to load
|
||||
if self.opt.debug:
|
||||
print("loading %d frames in total, first frame starting at index %d" % (n_frames_total, start_idx))
|
||||
n_frames_total, start_idx, t_step = get_video_params(self.opt, self.n_frames_total, len(A_paths), index)
|
||||
|
||||
# setting transformers
|
||||
B_img = Image.open(B_paths[0]).convert('RGB')
|
||||
params = get_params(self.opt, B_img.size)
|
||||
params = get_img_params(self.opt, B_img.size)
|
||||
transform_scaleB = get_transform(self.opt, params)
|
||||
transform_scaleA = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) if self.A_is_label else transform_scaleB
|
||||
|
||||
@ -68,9 +49,7 @@ class TemporalDataset(BaseDataset):
|
||||
for i in range(n_frames_total):
|
||||
A_path = A_paths[start_idx + i * t_step]
|
||||
B_path = B_paths[start_idx + i * t_step]
|
||||
Ai = self.get_image(A_path, transform_scaleA)
|
||||
if self.A_is_label:
|
||||
Ai = Ai * 255.0
|
||||
Ai = self.get_image(A_path, transform_scaleA, is_label=self.A_is_label)
|
||||
Bi = self.get_image(B_path, transform_scaleB)
|
||||
|
||||
A = Ai if i == 0 else torch.cat([A, Ai], dim=0)
|
||||
@ -81,21 +60,16 @@ class TemporalDataset(BaseDataset):
|
||||
Ii = self.get_image(I_path, transform_scaleA) * 255.0
|
||||
inst = Ii if i == 0 else torch.cat([inst, Ii], dim=0)
|
||||
|
||||
return_list = {'A': A, 'B': B, 'inst': inst, 'A_paths': A_path, 'B_paths': B_path}
|
||||
return_list = {'A': A, 'B': B, 'inst': inst, 'A_path': A_path, 'B_paths': B_path}
|
||||
return return_list
|
||||
|
||||
def get_image(self, A_path, transform_scaleA):
|
||||
def get_image(self, A_path, transform_scaleA, is_label=False):
|
||||
A_img = Image.open(A_path)
|
||||
A_scaled = transform_scaleA(A_img)
|
||||
A_scaled = transform_scaleA(A_img)
|
||||
if is_label:
|
||||
A_scaled *= 255.0
|
||||
return A_scaled
|
||||
|
||||
def update_training_batch(self, ratio): # update the training sequence length to be longer
|
||||
seq_len_max = min(128, self.seq_len_max) - (self.opt.n_frames_G - 1)
|
||||
if self.n_frames_total < seq_len_max:
|
||||
self.n_frames_total = min(seq_len_max, self.opt.n_frames_total * (2**ratio))
|
||||
#self.n_frames_total = min(seq_len_max, self.opt.n_frames_total * (ratio + 1))
|
||||
print('--------- Updating training sequence length to %d ---------' % self.n_frames_total)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.A_paths)
|
||||
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
import os.path
|
||||
import torch
|
||||
from data.base_dataset import BaseDataset, get_params, get_transform
|
||||
from data.image_folder import make_grouped_dataset
|
||||
from data.base_dataset import BaseDataset, get_img_params, get_transform, concat_frame
|
||||
from data.image_folder import make_grouped_dataset, check_path_valid
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
@ -11,63 +11,60 @@ class TestDataset(BaseDataset):
|
||||
def initialize(self, opt):
|
||||
self.opt = opt
|
||||
self.root = opt.dataroot
|
||||
self.dir_A = opt.dataroot
|
||||
self.dir_B = opt.dataroot.replace('test_A', 'test_B')
|
||||
self.dir_A = os.path.join(opt.dataroot, opt.phase + '_A')
|
||||
self.dir_B = os.path.join(opt.dataroot, opt.phase + '_B')
|
||||
self.use_real = opt.use_real_img
|
||||
self.A_is_label = self.opt.label_nc != 0
|
||||
|
||||
self.A_paths = sorted(make_grouped_dataset(self.dir_A))
|
||||
if self.use_real:
|
||||
self.B_paths = sorted(make_grouped_dataset(self.dir_B))
|
||||
assert(len(self.A_paths) == len(self.B_paths))
|
||||
check_path_valid(self.A_paths, self.B_paths)
|
||||
if self.opt.use_instance:
|
||||
self.dir_inst = opt.dataroot.replace('test_A', 'test_inst')
|
||||
self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
|
||||
self.I_paths = sorted(make_grouped_dataset(self.dir_inst))
|
||||
assert(len(self.A_paths) == len(self.I_paths))
|
||||
check_path_valid(self.A_paths, self.I_paths)
|
||||
|
||||
self.seq_idx = 0
|
||||
self.frame_idx = 0
|
||||
self.frames_count = []
|
||||
for path in self.A_paths:
|
||||
self.frames_count.append(len(path) - opt.n_frames_G + 1)
|
||||
self.init_frame_idx(self.A_paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
self.A, self.B, self.I, seq_idx = self.update_frame_idx(self.A_paths, index)
|
||||
tG = self.opt.n_frames_G
|
||||
change_seq = self.frame_idx >= self.frames_count[self.seq_idx]
|
||||
if change_seq:
|
||||
self.seq_idx += 1
|
||||
self.frame_idx = 0
|
||||
|
||||
A_img = Image.open(self.A_paths[self.seq_idx][0]).convert('RGB')
|
||||
params = get_params(self.opt, A_img.size)
|
||||
A_img = Image.open(self.A_paths[seq_idx][0]).convert('RGB')
|
||||
params = get_img_params(self.opt, A_img.size)
|
||||
transform_scaleB = get_transform(self.opt, params)
|
||||
transform_scaleA = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) if self.A_is_label else transform_scaleB
|
||||
|
||||
A = B = inst = 0
|
||||
for i in range(tG):
|
||||
A_path = self.A_paths[self.seq_idx][self.frame_idx + i]
|
||||
Ai = self.get_image(A_path, transform_scaleA)
|
||||
if self.A_is_label:
|
||||
Ai = Ai * 255.0
|
||||
A = Ai if i == 0 else torch.cat([A, Ai], dim=0)
|
||||
frame_range = list(range(tG)) if self.A is None else [tG-1]
|
||||
|
||||
for i in frame_range:
|
||||
A_path = self.A_paths[seq_idx][self.frame_idx + i]
|
||||
Ai = self.get_image(A_path, transform_scaleA, is_label=self.A_is_label)
|
||||
self.A = concat_frame(self.A, Ai, tG)
|
||||
|
||||
if self.use_real:
|
||||
B_path = self.B_paths[self.seq_idx][self.frame_idx + i]
|
||||
Bi = self.get_image(B_path, transform_scaleB)
|
||||
B = Bi if i == 0 else torch.cat([B, Bi], dim=0)
|
||||
B_path = self.B_paths[seq_idx][self.frame_idx + i]
|
||||
Bi = self.get_image(B_path, transform_scaleB)
|
||||
self.B = concat_frame(self.B, Bi, tG)
|
||||
else:
|
||||
self.B = 0
|
||||
|
||||
if self.opt.use_instance:
|
||||
I_path = self.I_paths[self.seq_idx][self.frame_idx + i]
|
||||
Ii = self.get_image(I_path, transform_scaleA) * 255.0
|
||||
inst = Ii if i == 0 else torch.cat([inst, Ii], dim=0)
|
||||
I_path = self.I_paths[seq_idx][self.frame_idx + i]
|
||||
Ii = self.get_image(I_path, transform_scaleA) * 255.0
|
||||
self.I = concat_frame(self.I, Ii, tG)
|
||||
else:
|
||||
self.I = 0
|
||||
|
||||
self.frame_idx += 1
|
||||
return_list = {'A': A, 'B': B, 'inst': inst, 'A_paths': A_path, 'change_seq': change_seq}
|
||||
self.frame_idx += 1
|
||||
return_list = {'A': self.A, 'B': self.B, 'inst': self.I, 'A_path': A_path, 'change_seq': self.change_seq}
|
||||
return return_list
|
||||
|
||||
def get_image(self, A_path, transform_scaleA):
|
||||
def get_image(self, A_path, transform_scaleA, is_label=False):
|
||||
A_img = Image.open(A_path)
|
||||
A_scaled = transform_scaleA(A_img)
|
||||
A_scaled = transform_scaleA(A_img)
|
||||
if is_label:
|
||||
A_scaled *= 255.0
|
||||
return A_scaled
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@ -66,8 +66,8 @@ class BaseModel(torch.nn.Module):
|
||||
save_path = os.path.join(save_dir, save_filename)
|
||||
if not os.path.isfile(save_path):
|
||||
print('%s not exists yet!' % save_path)
|
||||
#if 'G' in network_label:
|
||||
# raise('Generator must exist!')
|
||||
if 'G0' in network_label:
|
||||
raise('Generator must exist!')
|
||||
else:
|
||||
#network.load_state_dict(torch.load(save_path))
|
||||
try:
|
||||
|
||||
@ -39,11 +39,15 @@ def define_G(input_nc, output_nc, prev_output_nc, ngf, which_model_netG, n_downs
|
||||
netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsampling, opt.n_blocks, norm_layer)
|
||||
elif which_model_netG == 'local':
|
||||
netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsampling, opt.n_blocks, opt.n_local_enhancers, opt.n_blocks_local, norm_layer)
|
||||
elif which_model_netG == 'global_with_features':
|
||||
netG = Global_with_z(input_nc, output_nc, opt.feat_num, ngf, n_downsampling, opt.n_blocks, norm_layer)
|
||||
elif which_model_netG == 'local_with_features':
|
||||
netG = Local_with_z(input_nc, output_nc, opt.feat_num, ngf, n_downsampling, opt.n_blocks, opt.n_local_enhancers, opt.n_blocks_local, norm_layer)
|
||||
|
||||
elif which_model_netG == 'composite':
|
||||
netG = CompositeGenerator(input_nc, output_nc, prev_output_nc, ngf, n_downsampling, opt.n_blocks, opt.fg, norm_layer)
|
||||
netG = CompositeGenerator(input_nc, output_nc, prev_output_nc, ngf, n_downsampling, opt.n_blocks, opt.fg, opt.no_flow, norm_layer)
|
||||
elif which_model_netG == 'compositeLocal':
|
||||
netG = CompositeLocalGenerator(input_nc, output_nc, prev_output_nc, ngf, n_downsampling, opt.n_blocks_local, opt.fg,
|
||||
netG = CompositeLocalGenerator(input_nc, output_nc, prev_output_nc, ngf, n_downsampling, opt.n_blocks_local, opt.fg, opt.no_flow,
|
||||
norm_layer, scale=scale)
|
||||
elif which_model_netG == 'encoder':
|
||||
netG = Encoder(input_nc, output_nc, ngf, n_downsampling, norm_layer)
|
||||
@ -78,13 +82,14 @@ def print_network(net):
|
||||
# Classes
|
||||
##############################################################################
|
||||
class CompositeGenerator(nn.Module):
|
||||
def __init__(self, input_nc, output_nc, prev_output_nc, ngf, n_downsampling, n_blocks, use_fg_model=False,
|
||||
def __init__(self, input_nc, output_nc, prev_output_nc, ngf, n_downsampling, n_blocks, use_fg_model=False, no_flow=False,
|
||||
norm_layer=nn.BatchNorm2d, padding_type='reflect'):
|
||||
assert(n_blocks >= 0)
|
||||
super(CompositeGenerator, self).__init__()
|
||||
self.resample = Resample2d()
|
||||
self.n_downsampling = n_downsampling
|
||||
self.use_fg_model = use_fg_model
|
||||
self.no_flow = no_flow
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
if use_fg_model:
|
||||
@ -128,18 +133,21 @@ class CompositeGenerator(nn.Module):
|
||||
model_res_img = []
|
||||
for i in range(n_blocks//2):
|
||||
model_res_img += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
|
||||
model_res_flow = copy.deepcopy(model_res_img)
|
||||
if not no_flow:
|
||||
model_res_flow = copy.deepcopy(model_res_img)
|
||||
|
||||
### upsample
|
||||
model_up_img = []
|
||||
for i in range(n_downsampling):
|
||||
mult = 2**(n_downsampling - i)
|
||||
model_up_img += [nn.ConvTranspose2d(ngf*mult, ngf*mult//2, kernel_size=3, stride=2, padding=1, output_padding=1),
|
||||
norm_layer(ngf*mult//2), activation]
|
||||
model_up_flow = copy.deepcopy(model_up_img)
|
||||
norm_layer(ngf*mult//2), activation]
|
||||
model_final_img = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
||||
model_final_flow = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 2, kernel_size=7, padding=0)]
|
||||
model_final_w = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 1, kernel_size=7, padding=0), nn.Sigmoid()]
|
||||
|
||||
if not no_flow:
|
||||
model_up_flow = copy.deepcopy(model_up_img)
|
||||
model_final_flow = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 2, kernel_size=7, padding=0)]
|
||||
model_final_w = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 1, kernel_size=7, padding=0), nn.Sigmoid()]
|
||||
|
||||
if use_fg_model:
|
||||
self.indv_down = nn.Sequential(*indv_down)
|
||||
@ -150,25 +158,29 @@ class CompositeGenerator(nn.Module):
|
||||
self.model_down_seg = nn.Sequential(*model_down_seg)
|
||||
self.model_down_img = nn.Sequential(*model_down_img)
|
||||
self.model_res_img = nn.Sequential(*model_res_img)
|
||||
self.model_res_flow = nn.Sequential(*model_res_flow)
|
||||
self.model_up_img = nn.Sequential(*model_up_img)
|
||||
self.model_up_flow = nn.Sequential(*model_up_flow)
|
||||
self.model_final_img = nn.Sequential(*model_final_img)
|
||||
self.model_final_flow = nn.Sequential(*model_final_flow)
|
||||
self.model_final_w = nn.Sequential(*model_final_w)
|
||||
|
||||
if not no_flow:
|
||||
self.model_res_flow = nn.Sequential(*model_res_flow)
|
||||
self.model_up_flow = nn.Sequential(*model_up_flow)
|
||||
self.model_final_flow = nn.Sequential(*model_final_flow)
|
||||
self.model_final_w = nn.Sequential(*model_final_w)
|
||||
|
||||
def forward(self, input, img_prev, mask, img_feat_coarse, flow_feat_coarse, img_fg_feat_coarse, use_raw_only):
|
||||
downsample = self.model_down_seg(input) + self.model_down_img(img_prev)
|
||||
img_feat = self.model_up_img(self.model_res_img(downsample))
|
||||
res_flow = self.model_res_flow(downsample)
|
||||
flow_feat = self.model_up_flow(res_flow)
|
||||
|
||||
img_raw = self.model_final_img(img_feat)
|
||||
flow = self.model_final_flow(flow_feat) * 20
|
||||
weight = self.model_final_w(flow_feat)
|
||||
|
||||
flow = weight = flow_feat = None
|
||||
if not self.no_flow:
|
||||
res_flow = self.model_res_flow(downsample)
|
||||
flow_feat = self.model_up_flow(res_flow)
|
||||
flow = self.model_final_flow(flow_feat) * 20
|
||||
weight = self.model_final_w(flow_feat)
|
||||
|
||||
gpu_id = img_feat.get_device()
|
||||
if use_raw_only:
|
||||
if use_raw_only or self.no_flow:
|
||||
img_final = img_raw
|
||||
else:
|
||||
img_warp = self.resample(img_prev[:,-3:,...].cuda(gpu_id), flow).cuda(gpu_id)
|
||||
@ -187,11 +199,12 @@ class CompositeGenerator(nn.Module):
|
||||
return img_final, flow, weight, img_raw, img_feat, flow_feat, img_fg_feat
|
||||
|
||||
class CompositeLocalGenerator(nn.Module):
|
||||
def __init__(self, input_nc, output_nc, prev_output_nc, ngf, n_downsampling, n_blocks_local, use_fg_model=False,
|
||||
def __init__(self, input_nc, output_nc, prev_output_nc, ngf, n_downsampling, n_blocks_local, use_fg_model=False, no_flow=False,
|
||||
norm_layer=nn.BatchNorm2d, padding_type='reflect', scale=1):
|
||||
super(CompositeLocalGenerator, self).__init__()
|
||||
self.resample = Resample2d()
|
||||
self.use_fg_model = use_fg_model
|
||||
self.no_flow = no_flow
|
||||
self.scale = scale
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
@ -218,20 +231,19 @@ class CompositeLocalGenerator(nn.Module):
|
||||
nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=1), norm_layer(ngf*2), activation]
|
||||
|
||||
### resnet blocks
|
||||
model_up_img = []
|
||||
model_up_flow = []
|
||||
model_up_img = []
|
||||
for i in range(n_blocks_local):
|
||||
model_up_img += [ResnetBlock(ngf*2, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
|
||||
model_up_flow += [ResnetBlock(ngf*2, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
|
||||
model_up_img += [ResnetBlock(ngf*2, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
|
||||
|
||||
### upsample
|
||||
up = [nn.ConvTranspose2d(ngf*2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1), norm_layer(ngf), activation]
|
||||
model_up_img += up
|
||||
model_up_flow += copy.deepcopy(up)
|
||||
|
||||
model_final_img = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
||||
model_final_flow = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 2, kernel_size=7, padding=0)]
|
||||
model_final_w = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 1, kernel_size=7, padding=0), nn.Sigmoid()]
|
||||
|
||||
if not no_flow:
|
||||
model_up_flow = copy.deepcopy(model_up_img)
|
||||
model_final_flow = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 2, kernel_size=7, padding=0)]
|
||||
model_final_w = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 1, kernel_size=7, padding=0), nn.Sigmoid()]
|
||||
|
||||
if use_fg_model:
|
||||
self.indv_down = nn.Sequential(*indv_down)
|
||||
@ -241,10 +253,12 @@ class CompositeLocalGenerator(nn.Module):
|
||||
self.model_down_seg = nn.Sequential(*model_down_seg)
|
||||
self.model_down_img = nn.Sequential(*model_down_img)
|
||||
self.model_up_img = nn.Sequential(*model_up_img)
|
||||
self.model_up_flow = nn.Sequential(*model_up_flow)
|
||||
self.model_final_img = nn.Sequential(*model_final_img)
|
||||
self.model_final_flow = nn.Sequential(*model_final_flow)
|
||||
self.model_final_w = nn.Sequential(*model_final_w)
|
||||
|
||||
if not no_flow:
|
||||
self.model_up_flow = nn.Sequential(*model_up_flow)
|
||||
self.model_final_flow = nn.Sequential(*model_final_flow)
|
||||
self.model_final_w = nn.Sequential(*model_final_w)
|
||||
|
||||
def forward(self, input, img_prev, mask, img_feat_coarse, flow_feat_coarse, img_fg_feat_coarse, use_raw_only):
|
||||
flow_multiplier = 20 * (2 ** self.scale)
|
||||
@ -252,13 +266,15 @@ class CompositeLocalGenerator(nn.Module):
|
||||
img_feat = self.model_up_img(down_img + img_feat_coarse)
|
||||
img_raw = self.model_final_img(img_feat)
|
||||
|
||||
down_flow = down_img
|
||||
flow_feat = self.model_up_flow(down_flow + flow_feat_coarse)
|
||||
flow = self.model_final_flow(flow_feat) * flow_multiplier
|
||||
weight = self.model_final_w(flow_feat)
|
||||
flow = weight = flow_feat = None
|
||||
if not self.no_flow:
|
||||
down_flow = down_img
|
||||
flow_feat = self.model_up_flow(down_flow + flow_feat_coarse)
|
||||
flow = self.model_final_flow(flow_feat) * flow_multiplier
|
||||
weight = self.model_final_w(flow_feat)
|
||||
|
||||
gpu_id = img_feat.get_device()
|
||||
if use_raw_only:
|
||||
if use_raw_only or self.no_flow:
|
||||
img_final = img_raw
|
||||
else:
|
||||
img_warp = self.resample(img_prev[:,-3:,...].cuda(gpu_id), flow).cuda(gpu_id)
|
||||
@ -303,7 +319,7 @@ class GlobalGenerator(nn.Module):
|
||||
model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, input, img_feat_coarse=None, feat=None):
|
||||
def forward(self, input, feat=None):
|
||||
if feat is not None:
|
||||
input = torch.cat([input, feat], dim=1)
|
||||
output = self.model(input)
|
||||
@ -369,6 +385,138 @@ class LocalEnhancer(nn.Module):
|
||||
output_prev = model_upsample(model_downsample(input_i) + output_prev)
|
||||
return output_prev
|
||||
|
||||
class Global_with_z(nn.Module):
|
||||
def __init__(self, input_nc, output_nc, nz, ngf=64, n_downsample_G=3, n_blocks=9,
|
||||
norm_layer=nn.BatchNorm2d, padding_type='reflect'):
|
||||
super(Global_with_z, self).__init__()
|
||||
self.n_downsample_G = n_downsample_G
|
||||
max_ngf = 1024
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
# downsample model
|
||||
model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc + nz, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
|
||||
for i in range(n_downsample_G):
|
||||
mult = 2 ** i
|
||||
model_downsample += [nn.Conv2d(min(ngf * mult, max_ngf), min(ngf * mult * 2, max_ngf), kernel_size=3, stride=2, padding=1),
|
||||
norm_layer(min(ngf * mult * 2, max_ngf)), activation]
|
||||
|
||||
# internal model
|
||||
model_resnet = []
|
||||
mult = 2 ** n_downsample_G
|
||||
for i in range(n_blocks):
|
||||
model_resnet += [ResnetBlock(min(ngf*mult, max_ngf) + nz, padding_type=padding_type, norm_layer=norm_layer)]
|
||||
|
||||
# upsample model
|
||||
model_upsample = []
|
||||
for i in range(n_downsample_G):
|
||||
mult = 2 ** (n_downsample_G - i)
|
||||
input_ngf = min(ngf * mult, max_ngf)
|
||||
if i == 0:
|
||||
input_ngf += nz * 2
|
||||
model_upsample += [nn.ConvTranspose2d(input_ngf, min((ngf * mult // 2), max_ngf), kernel_size=3, stride=2,
|
||||
padding=1, output_padding=1), norm_layer(min((ngf * mult // 2), max_ngf)), activation]
|
||||
|
||||
model_upsample_conv = [nn.ReflectionPad2d(3), nn.Conv2d(ngf + nz, output_nc, kernel_size=7), nn.Tanh()]
|
||||
|
||||
self.model_downsample = nn.Sequential(*model_downsample)
|
||||
self.model_resnet = nn.Sequential(*model_resnet)
|
||||
self.model_upsample = nn.Sequential(*model_upsample)
|
||||
self.model_upsample_conv = nn.Sequential(*model_upsample_conv)
|
||||
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
|
||||
|
||||
def forward(self, x, z):
|
||||
z_downsample = z
|
||||
for i in range(self.n_downsample_G):
|
||||
z_downsample = self.downsample(z_downsample)
|
||||
downsample = self.model_downsample(torch.cat([x, z], dim=1))
|
||||
resnet = self.model_resnet(torch.cat([downsample, z_downsample], dim=1))
|
||||
upsample = self.model_upsample(torch.cat([resnet, z_downsample], dim=1))
|
||||
return self.model_upsample_conv(torch.cat([upsample, z], dim=1))
|
||||
|
||||
class Local_with_z(nn.Module):
|
||||
def __init__(self, input_nc, output_nc, nz, ngf=32, n_downsample_global=3, n_blocks_global=9,
|
||||
n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
|
||||
super(Local_with_z, self).__init__()
|
||||
self.n_local_enhancers = n_local_enhancers
|
||||
self.n_downsample_global = n_downsample_global
|
||||
|
||||
###### global generator model #####
|
||||
ngf_global = ngf * (2**n_local_enhancers)
|
||||
model_global = Global_with_z(input_nc, output_nc, nz, ngf_global, n_downsample_global, n_blocks_global, norm_layer)
|
||||
self.model_downsample = model_global.model_downsample
|
||||
self.model_resnet = model_global.model_resnet
|
||||
self.model_upsample = model_global.model_upsample
|
||||
|
||||
###### local enhancer layers #####
|
||||
for n in range(1, n_local_enhancers+1):
|
||||
### downsample
|
||||
ngf_global = ngf * (2**(n_local_enhancers-n))
|
||||
if n == n_local_enhancers:
|
||||
input_nc += nz
|
||||
model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7),
|
||||
norm_layer(ngf_global), nn.ReLU(True),
|
||||
nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
|
||||
norm_layer(ngf_global * 2), nn.ReLU(True)]
|
||||
### residual blocks
|
||||
model_upsample = []
|
||||
input_ngf = ngf_global * 2
|
||||
if n == 1:
|
||||
input_ngf += nz
|
||||
for i in range(n_blocks_local):
|
||||
model_upsample += [ResnetBlock(input_ngf, padding_type=padding_type, norm_layer=norm_layer)]
|
||||
### upsample
|
||||
model_upsample += [nn.ConvTranspose2d(input_ngf, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
|
||||
norm_layer(ngf_global), nn.ReLU(True)]
|
||||
|
||||
setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
|
||||
setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
|
||||
|
||||
### final convolution
|
||||
model_final = [nn.ReflectionPad2d(3), nn.Conv2d(ngf + nz, output_nc, kernel_size=7), nn.Tanh()]
|
||||
self.model_final = nn.Sequential(*model_final)
|
||||
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
|
||||
|
||||
def forward(self, input, z):
|
||||
### create input pyramid
|
||||
input_downsampled = [input]
|
||||
for i in range(self.n_local_enhancers):
|
||||
input_downsampled.append(self.downsample(input_downsampled[-1]))
|
||||
|
||||
### create downsampled z
|
||||
z_downsampled_local = z
|
||||
for i in range(self.n_local_enhancers):
|
||||
z_downsampled_local = self.downsample(z_downsampled_local)
|
||||
z_downsampled_global = z_downsampled_local
|
||||
for i in range(self.n_downsample_global):
|
||||
z_downsampled_global = self.downsample(z_downsampled_global)
|
||||
|
||||
### output at coarest level
|
||||
x = input_downsampled[-1]
|
||||
global_downsample = self.model_downsample(torch.cat([x, z_downsampled_local], dim=1))
|
||||
global_resnet = self.model_resnet(torch.cat([global_downsample, z_downsampled_global], dim=1))
|
||||
global_upsample = self.model_upsample(torch.cat([global_resnet, z_downsampled_global], dim=1))
|
||||
|
||||
### build up one layer at a time
|
||||
output_prev = global_upsample
|
||||
for n_local_enhancers in range(1, self.n_local_enhancers+1):
|
||||
# fetch models
|
||||
model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
|
||||
model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
|
||||
# get input image
|
||||
input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
|
||||
if n_local_enhancers == self.n_local_enhancers:
|
||||
input_i = torch.cat([input_i, z], dim=1)
|
||||
# combine features from different resolutions
|
||||
combined_input = model_downsample(input_i) + output_prev
|
||||
if n_local_enhancers == 1:
|
||||
combined_input = torch.cat([combined_input, z_downsampled_local], dim=1)
|
||||
# upsample features
|
||||
output_prev = model_upsample(combined_input)
|
||||
|
||||
# final convolution
|
||||
output = self.model_final(torch.cat([output_prev, z], dim=1))
|
||||
return output
|
||||
|
||||
# Define a resnet block
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
|
||||
|
||||
@ -35,6 +35,10 @@ class Vid2VidModelD(BaseModel):
|
||||
|
||||
self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm,
|
||||
opt.num_D, not opt.no_ganFeat, gpu_ids=self.gpu_ids)
|
||||
|
||||
if opt.add_face_disc:
|
||||
self.netD_f = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm,
|
||||
opt.num_D - 2, not opt.no_ganFeat, gpu_ids=self.gpu_ids)
|
||||
|
||||
# temporal discriminator
|
||||
netD_input_nc = opt.output_nc * opt.n_frames_D + 2 * (opt.n_frames_D-1)
|
||||
@ -50,9 +54,11 @@ class Vid2VidModelD(BaseModel):
|
||||
|
||||
# load networks
|
||||
if opt.continue_train or opt.load_pretrain:
|
||||
self.load_network(self.netD, 'D', opt.which_epoch, opt.load_pretrain)
|
||||
self.load_network(self.netD, 'D', opt.which_epoch, opt.load_pretrain)
|
||||
for s in range(opt.n_scales_temporal):
|
||||
self.load_network(getattr(self, 'netD_T'+str(s)), 'D_T'+str(s), opt.which_epoch, opt.load_pretrain)
|
||||
self.load_network(getattr(self, 'netD_T'+str(s)), 'D_T'+str(s), opt.which_epoch, opt.load_pretrain)
|
||||
if opt.add_face_disc:
|
||||
self.load_network(self.netD_f, 'D_f', opt.which_epoch, opt.load_pretrain)
|
||||
|
||||
# set loss functions and optimizers
|
||||
self.old_lr = opt.lr
|
||||
@ -68,9 +74,13 @@ class Vid2VidModelD(BaseModel):
|
||||
'D_real', 'D_fake',
|
||||
'G_Warp', 'F_Flow', 'F_Warp', 'W']
|
||||
self.loss_names_T = ['G_T_GAN', 'G_T_GAN_Feat', 'D_T_real', 'D_T_fake', 'G_T_Warp']
|
||||
if opt.add_face_disc:
|
||||
self.loss_names += ['G_f_GAN', 'G_f_GAN_Feat', 'D_f_real', 'D_f_fake']
|
||||
|
||||
# initialize optimizers D and D_T
|
||||
params = list(self.netD.parameters())
|
||||
params = list(self.netD.parameters())
|
||||
if opt.add_face_disc:
|
||||
params += list(self.netD_f.parameters())
|
||||
if opt.TTUR:
|
||||
beta1, beta2 = 0, 0.9
|
||||
lr = opt.lr * 2
|
||||
@ -83,11 +93,8 @@ class Vid2VidModelD(BaseModel):
|
||||
params = list(getattr(self, 'netD_T'+str(s)).parameters())
|
||||
optimizer_D_T = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
setattr(self, 'optimizer_D_T'+str(s), optimizer_D_T)
|
||||
|
||||
self.downsample = torch.nn.AvgPool2d(2, stride=2)
|
||||
|
||||
def compute_loss_D(self, real_A, real_B, fake_B):
|
||||
netD = self.netD
|
||||
def compute_loss_D(self, netD, real_A, real_B, fake_B):
|
||||
real_AB = torch.cat((real_A, real_B), dim=1)
|
||||
fake_AB = torch.cat((real_A, fake_B), dim=1)
|
||||
pred_real = netD.forward(real_AB)
|
||||
@ -156,23 +163,26 @@ class Vid2VidModelD(BaseModel):
|
||||
real_B, fake_B, fake_B_raw, real_A, real_B_prev, fake_B_prev, flow, weight, flow_ref, conf_ref = tensors_list
|
||||
_, _, self.height, self.width = real_B.size()
|
||||
|
||||
################### Flow loss #################
|
||||
# similar to flownet flow
|
||||
loss_F_Flow = self.criterionFlow(flow, flow_ref, conf_ref) * lambda_F / (2 ** (scale_S-1))
|
||||
# warped prev image should be close to current image
|
||||
real_B_warp = self.resample(real_B_prev, flow)
|
||||
loss_F_Warp = self.criterionFlow(real_B_warp, real_B, conf_ref) * lambda_T
|
||||
|
||||
################## weight loss ##################
|
||||
loss_W = torch.zeros_like(weight)
|
||||
if self.opt.no_first_img:
|
||||
dummy0 = torch.zeros_like(weight)
|
||||
loss_W = self.criterionFeat(weight, dummy0)
|
||||
|
||||
################### Flow loss #################
|
||||
if flow is not None:
|
||||
# similar to flownet flow
|
||||
loss_F_Flow = self.criterionFlow(flow, flow_ref, conf_ref) * lambda_F / (2 ** (scale_S-1))
|
||||
# warped prev image should be close to current image
|
||||
real_B_warp = self.resample(real_B_prev, flow)
|
||||
loss_F_Warp = self.criterionFlow(real_B_warp, real_B, conf_ref) * lambda_T
|
||||
|
||||
################## weight loss ##################
|
||||
loss_W = torch.zeros_like(weight)
|
||||
if self.opt.no_first_img:
|
||||
dummy0 = torch.zeros_like(weight)
|
||||
loss_W = self.criterionFlow(weight, dummy0, conf_ref)
|
||||
else:
|
||||
loss_F_Flow = loss_F_Warp = loss_W = torch.zeros_like(conf_ref)
|
||||
|
||||
#################### fake_B loss ####################
|
||||
### VGG + GAN loss
|
||||
loss_G_VGG = (self.criterionVGG(fake_B, real_B) * lambda_feat) if not self.opt.no_vgg else torch.zeros_like(loss_W)
|
||||
loss_D_real, loss_D_fake, loss_G_GAN, loss_G_GAN_Feat = self.compute_loss_D(real_A, real_B, fake_B)
|
||||
loss_D_real, loss_D_fake, loss_G_GAN, loss_G_GAN_Feat = self.compute_loss_D(self.netD, real_A, real_B, fake_B)
|
||||
### Warp loss
|
||||
fake_B_warp_ref = self.resample(fake_B_prev, flow_ref)
|
||||
loss_G_Warp = self.criterionWarp(fake_B, fake_B_warp_ref.detach(), conf_ref) * lambda_T
|
||||
@ -180,20 +190,52 @@ class Vid2VidModelD(BaseModel):
|
||||
if fake_B_raw is not None:
|
||||
if not self.opt.no_vgg:
|
||||
loss_G_VGG += self.criterionVGG(fake_B_raw, real_B) * lambda_feat
|
||||
l_D_real, l_D_fake, l_G_GAN, l_G_GAN_Feat = self.compute_loss_D(real_A, real_B, fake_B_raw)
|
||||
l_D_real, l_D_fake, l_G_GAN, l_G_GAN_Feat = self.compute_loss_D(self.netD, real_A, real_B, fake_B_raw)
|
||||
loss_G_GAN += l_G_GAN; loss_G_GAN_Feat += l_G_GAN_Feat
|
||||
loss_D_real += l_D_real; loss_D_fake += l_D_fake
|
||||
loss_D_real += l_D_real; loss_D_fake += l_D_fake
|
||||
|
||||
if self.opt.add_face_disc:
|
||||
face_weight = 2
|
||||
ys, ye, xs, xe = self.get_face_region(real_A)
|
||||
if ys is not None:
|
||||
loss_D_f_real, loss_D_f_fake, loss_G_f_GAN, loss_G_f_GAN_Feat = self.compute_loss_D(self.netD_f,
|
||||
real_A[:,:,ys:ye,xs:xe], real_B[:,:,ys:ye,xs:xe], fake_B[:,:,ys:ye,xs:xe])
|
||||
loss_G_f_GAN *= face_weight
|
||||
loss_G_f_GAN_Feat *= face_weight
|
||||
else:
|
||||
loss_D_f_real = loss_D_f_fake = loss_G_f_GAN = loss_G_f_GAN_Feat = torch.zeros_like(loss_D_real)
|
||||
|
||||
loss_list = [loss_G_VGG, loss_G_GAN, loss_G_GAN_Feat,
|
||||
loss_D_real, loss_D_fake,
|
||||
loss_G_Warp, loss_F_Flow, loss_F_Warp, loss_W]
|
||||
loss_G_Warp, loss_F_Flow, loss_F_Warp, loss_W]
|
||||
if self.opt.add_face_disc:
|
||||
loss_list += [loss_G_f_GAN, loss_G_f_GAN_Feat, loss_D_f_real, loss_D_f_fake]
|
||||
loss_list = [loss.unsqueeze(0) for loss in loss_list]
|
||||
return loss_list
|
||||
|
||||
def get_face_region(self, real_A):
|
||||
_, _, h, w = real_A.size()
|
||||
if not self.opt.openpose_only:
|
||||
face = (real_A[:,2] > 0.9).nonzero()
|
||||
else:
|
||||
face = (((real_A[:,0] == 0.6) | (real_A[:,0] == 0.2)) & (real_A[:,1] == 0) & (real_A[:,2] == 0.6)).nonzero()
|
||||
if face.size()[0]:
|
||||
y, x = face[:,1], face[:,2]
|
||||
ys, ye, xs, xe = y.min().item(), y.max().item(), x.min().item(), x.max().item()
|
||||
yc, ylen = int(ys+ye)//2, self.opt.fineSize//32*8
|
||||
xc, xlen = int(xs+xe)//2, self.opt.fineSize//32*8
|
||||
yc = max(ylen//2, min(h-1 - ylen//2, yc))
|
||||
xc = max(xlen//2, min(w-1 - xlen//2, xc))
|
||||
ys, ye, xs, xe = yc - ylen//2, yc + ylen//2, xc - xlen//2, xc + xlen//2
|
||||
return ys, ye, xs, xe
|
||||
return None, None, None, None
|
||||
|
||||
def save(self, label):
|
||||
self.save_network(self.netD, 'D', label, self.gpu_ids)
|
||||
for s in range(self.opt.n_scales_temporal):
|
||||
self.save_network(getattr(self, 'netD_T'+str(s)), 'D_T'+str(s), label, self.gpu_ids)
|
||||
if self.opt.add_face_disc:
|
||||
self.save_network(self.netD_f, 'D_f', label, self.gpu_ids)
|
||||
|
||||
def update_learning_rate(self, epoch):
|
||||
lr = self.opt.lr * (1 - (epoch - self.opt.niter) / self.opt.niter_decay)
|
||||
|
||||
@ -25,13 +25,15 @@ class Vid2VidModelG(BaseModel):
|
||||
# define net G
|
||||
self.n_scales = opt.n_scales_spatial
|
||||
self.use_single_G = opt.use_single_G
|
||||
self.split_gpus = self.opt.n_gpus_gen > self.opt.batchSize
|
||||
self.split_gpus = (self.opt.n_gpus_gen < len(self.opt.gpu_ids)) and (self.opt.batchSize == 1)
|
||||
|
||||
input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
|
||||
netG_input_nc = input_nc * opt.n_frames_G
|
||||
if opt.use_instance:
|
||||
netG_input_nc += opt.n_frames_G
|
||||
prev_output_nc = (opt.n_frames_G - 1) * opt.output_nc
|
||||
prev_output_nc = (opt.n_frames_G - 1) * opt.output_nc
|
||||
if opt.openpose_only:
|
||||
opt.no_flow = True
|
||||
|
||||
self.netG0 = networks.define_G(netG_input_nc, opt.output_nc, prev_output_nc, opt.ngf, opt.netG,
|
||||
opt.n_downsample_G, opt.norm, 0, self.gpu_ids, opt)
|
||||
@ -97,24 +99,27 @@ class Vid2VidModelG(BaseModel):
|
||||
if self.opt.use_instance:
|
||||
inst_map = inst_map.data.cuda()
|
||||
edge_map = Variable(self.get_edges(inst_map))
|
||||
input_map = torch.cat([input_map, edge_map], dim=2)
|
||||
input_map = torch.cat([input_map, edge_map], dim=2)
|
||||
|
||||
pool_map = None
|
||||
if self.opt.dataset_mode == 'face':
|
||||
pool_map = inst_map.data.cuda()
|
||||
|
||||
# real images for training
|
||||
if real_image is not None:
|
||||
real_image = Variable(real_image.data.cuda())
|
||||
|
||||
return input_map, real_image
|
||||
return input_map, real_image, pool_map
|
||||
|
||||
def forward(self, input_A, input_B, inst_A, fake_B_prev):
|
||||
tG = self.opt.n_frames_G
|
||||
gpu_split_id = self.opt.n_gpus_gen + 1
|
||||
real_A_all, real_B_all = self.encode_input(input_A, input_B, inst_A)
|
||||
real_A_all, real_B_all, _ = self.encode_input(input_A, input_B, inst_A)
|
||||
|
||||
is_first_frame = fake_B_prev is None
|
||||
if is_first_frame: # at the beginning of a sequence; needs to generate the first frame
|
||||
fake_B_prev = self.generate_first_frame(real_A_all, real_B_all)
|
||||
|
||||
fake_Bs, fake_Bs_raw, flows, weights = None, None, None, None
|
||||
|
||||
netG = []
|
||||
for s in range(self.n_scales): # broadcast netG to all GPUs used for generator
|
||||
netG_s = getattr(self, 'netG'+str(s))
|
||||
@ -171,8 +176,9 @@ class Vid2VidModelG(BaseModel):
|
||||
|
||||
# if only training the finest scale, leave the coarser levels untouched
|
||||
if s != n_scales-1 and not finetune_all:
|
||||
fake_B, flow = fake_B.detach(), flow.detach()
|
||||
fake_B_feat, flow_feat = fake_B_feat.detach(), flow_feat.detach()
|
||||
fake_B, fake_B_feat = fake_B.detach(), fake_B_feat.detach()
|
||||
if flow is not None:
|
||||
flow, flow_feat = flow.detach(), flow_feat.detach()
|
||||
if fake_B_fg_feat is not None:
|
||||
fake_B_fg_feat = fake_B_fg_feat.detach()
|
||||
|
||||
@ -180,23 +186,24 @@ class Vid2VidModelG(BaseModel):
|
||||
fake_B_pyr[si] = self.concat([fake_B_pyr[si], fake_B.unsqueeze(1).cuda(dest_id)], dim=1)
|
||||
if s == n_scales-1:
|
||||
fake_Bs_raw = self.concat([fake_Bs_raw, fake_B_raw.unsqueeze(1).cuda(dest_id)], dim=1)
|
||||
flows = self.concat([flows, flow.unsqueeze(1).cuda(dest_id)], dim=1)
|
||||
weights = self.concat([weights, weight.unsqueeze(1).cuda(dest_id)], dim=1)
|
||||
if flow is not None:
|
||||
flows = self.concat([flows, flow.unsqueeze(1).cuda(dest_id)], dim=1)
|
||||
weights = self.concat([weights, weight.unsqueeze(1).cuda(dest_id)], dim=1)
|
||||
|
||||
return fake_B_pyr, fake_Bs_raw, flows, weights
|
||||
|
||||
def inference(self, input_A, input_B, inst_A):
|
||||
with torch.no_grad():
|
||||
real_A, real_B = self.encode_input(input_A, input_B, inst_A)
|
||||
real_A, real_B, pool_map = self.encode_input(input_A, input_B, inst_A)
|
||||
self.is_first_frame = not hasattr(self, 'fake_B_prev') or self.fake_B_prev is None
|
||||
if self.is_first_frame:
|
||||
self.fake_B_prev = self.generate_first_frame(real_A, real_B)
|
||||
self.fake_B_prev = self.generate_first_frame(real_A, real_B, pool_map)
|
||||
|
||||
real_A = self.build_pyr(real_A)
|
||||
self.fake_B_feat = self.flow_feat = self.fake_B_fg_feat = None
|
||||
for s in range(self.n_scales):
|
||||
fake_B = self.generate_frame_infer(real_A[self.n_scales-1-s], s)
|
||||
return fake_B, real_A[0][-1:]
|
||||
return fake_B, real_A[0][0, -1]
|
||||
|
||||
def generate_frame_infer(self, real_A, s):
|
||||
tG = self.opt.n_frames_G
|
||||
@ -205,7 +212,7 @@ class Vid2VidModelG(BaseModel):
|
||||
netG_s = getattr(self, 'netG'+str(s))
|
||||
|
||||
### prepare inputs
|
||||
real_As_reshaped = real_A[0,:tG].view(1, -1, h, w)
|
||||
real_As_reshaped = real_A[0,:tG].view(1, -1, h, w)
|
||||
fake_B_prevs_reshaped = self.fake_B_prev[si].view(1, -1, h, w)
|
||||
mask_F = self.compute_mask(real_A, tG-1)[0] if self.opt.fg else None
|
||||
use_raw_only = self.opt.no_first_img and self.is_first_frame
|
||||
@ -218,17 +225,19 @@ class Vid2VidModelG(BaseModel):
|
||||
self.fake_B_prev[si] = torch.cat([self.fake_B_prev[si][1:,...], fake_B])
|
||||
return fake_B
|
||||
|
||||
def generate_first_frame(self, real_A=None, real_B=None):
|
||||
def generate_first_frame(self, real_A, real_B, pool_map=None):
|
||||
tG = self.opt.n_frames_G
|
||||
if self.opt.no_first_img: # model also generates first frame
|
||||
fake_B_prev = Variable(self.Tensor(self.bs, tG-1, self.opt.output_nc, self.height, self.width).zero_())
|
||||
elif self.opt.isTrain or self.opt.use_real_img: # assume first frame is given
|
||||
fake_B_prev = real_B[:,:(tG-1),...]
|
||||
elif self.opt.use_single_G: # use another model (trained on single images) to generate first frame
|
||||
fake_B_prev = None
|
||||
real_A = real_A[:,:,:self.opt.label_nc,:,:]
|
||||
fake_B_prev = None
|
||||
if self.opt.use_instance:
|
||||
real_A = real_A[:,:,:self.opt.label_nc,:,:]
|
||||
for i in range(tG-1):
|
||||
fake_B = self.netG_i.forward(real_A[:,i]).unsqueeze(1)
|
||||
feat_map = self.get_face_features(real_B[:,i], pool_map[:,i]) if self.opt.dataset_mode == 'face' else None
|
||||
fake_B = self.netG_i.forward(real_A[:,i], feat_map).unsqueeze(1)
|
||||
fake_B_prev = self.concat([fake_B_prev, fake_B], dim=1)
|
||||
else:
|
||||
raise ValueError('Please specify the method for generating the first frame')
|
||||
@ -255,24 +264,74 @@ class Vid2VidModelG(BaseModel):
|
||||
def load_single_G(self): # load the model that generates the first frame
|
||||
opt = self.opt
|
||||
s = self.n_scales
|
||||
single_path = 'checkpoints/label2city_single/'
|
||||
net_name = 'latest_net_G.pth'
|
||||
input_nc = opt.label_nc
|
||||
|
||||
if opt.loadSize == 512:
|
||||
load_path = single_path + 'latest_net_G_512.pth'
|
||||
netG = networks.define_G(input_nc, opt.output_nc, 0, 64, 'global', 3, 'instance', 0, self.gpu_ids, opt)
|
||||
elif opt.loadSize == 1024:
|
||||
load_path = single_path + 'latest_net_G_1024.pth'
|
||||
netG = networks.define_G(input_nc, opt.output_nc, 0, 64, 'global', 4, 'instance', 0, self.gpu_ids, opt)
|
||||
elif opt.loadSize == 2048:
|
||||
load_path = single_path + 'latest_net_G_2048.pth'
|
||||
netG = networks.define_G(input_nc, opt.output_nc, 0, 32, 'local', 4, 'instance', 0, self.gpu_ids, opt)
|
||||
if 'City' in self.opt.dataroot:
|
||||
single_path = 'checkpoints/label2city_single/'
|
||||
if opt.loadSize == 512:
|
||||
load_path = single_path + 'latest_net_G_512.pth'
|
||||
netG = networks.define_G(35, 3, 0, 64, 'global', 3, 'instance', 0, self.gpu_ids, opt)
|
||||
elif opt.loadSize == 1024:
|
||||
load_path = single_path + 'latest_net_G_1024.pth'
|
||||
netG = networks.define_G(35, 3, 0, 64, 'global', 4, 'instance', 0, self.gpu_ids, opt)
|
||||
elif opt.loadSize == 2048:
|
||||
load_path = single_path + 'latest_net_G_2048.pth'
|
||||
netG = networks.define_G(35, 3, 0, 32, 'local', 4, 'instance', 0, self.gpu_ids, opt)
|
||||
else:
|
||||
raise ValueError('Single image generator does not exist')
|
||||
elif 'face' in self.opt.dataroot:
|
||||
single_path = 'checkpoints/edge2face_single/'
|
||||
load_path = single_path + 'latest_net_G.pth'
|
||||
opt.feat_num = 16
|
||||
netG = networks.define_G(15, 3, 0, 64, 'global_with_features', 3, 'instance', 0, self.gpu_ids, opt)
|
||||
encoder_path = single_path + 'latest_net_E.pth'
|
||||
self.netE = networks.define_G(3, 16, 0, 16, 'encoder', 4, 'instance', 0, self.gpu_ids)
|
||||
self.netE.load_state_dict(torch.load(encoder_path))
|
||||
else:
|
||||
raise ValueError('Single image generator does not exist')
|
||||
netG.load_state_dict(torch.load(load_path))
|
||||
return netG
|
||||
|
||||
def get_face_features(self, real_image, inst):
|
||||
feat_map = self.netE.forward(real_image, inst)
|
||||
#if self.opt.use_encoded_image:
|
||||
# return feat_map
|
||||
|
||||
load_name = 'checkpoints/edge2face_single/features.npy'
|
||||
features = np.load(load_name, encoding='latin1').item()
|
||||
inst_np = inst.cpu().numpy().astype(int)
|
||||
|
||||
# find nearest neighbor in the training dataset
|
||||
num_images = features[6].shape[0]
|
||||
feat_map = feat_map.data.cpu().numpy()
|
||||
feat_ori = torch.FloatTensor(7, self.opt.feat_num, 1) # feature map for test img (for each facial part)
|
||||
feat_ref = torch.FloatTensor(7, self.opt.feat_num, num_images) # feature map for training imgs
|
||||
for label in np.unique(inst_np):
|
||||
idx = (inst == int(label)).nonzero()
|
||||
for k in range(self.opt.feat_num):
|
||||
feat_ori[label,k] = float(feat_map[idx[0,0], idx[0,1] + k, idx[0,2], idx[0,3]])
|
||||
for m in range(num_images):
|
||||
feat_ref[label,k,m] = features[label][m,k]
|
||||
cluster_idx = self.dists_min(feat_ori.expand_as(feat_ref).cuda(), feat_ref.cuda(), num=1)
|
||||
|
||||
# construct new feature map from nearest neighbors
|
||||
feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
|
||||
for label in np.unique(inst_np):
|
||||
feat = features[label][:,:-1]
|
||||
idx = (inst == int(label)).nonzero()
|
||||
for k in range(self.opt.feat_num):
|
||||
feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[min(cluster_idx, feat.shape[0]-1), k]
|
||||
|
||||
return Variable(feat_map)
|
||||
|
||||
def dists_min(self, a, b, num=1):
|
||||
dists = torch.sum(torch.sum((a-b)*(a-b), dim=0), dim=0)
|
||||
if num == 1:
|
||||
val, idx = torch.min(dists, dim=0)
|
||||
#idx = [idx]
|
||||
else:
|
||||
val, idx = torch.sort(dists, dim=0)
|
||||
idx = idx[:num]
|
||||
return idx.cpu().numpy().astype(int)
|
||||
|
||||
def get_edges(self, t):
|
||||
edge = torch.cuda.ByteTensor(t.size()).zero_()
|
||||
edge[:,:,:,:,1:] = edge[:,:,:,:,1:] | (t[:,:,:,:,1:] != t[:,:,:,:,:-1])
|
||||
|
||||
@ -14,7 +14,7 @@ class BaseOptions():
|
||||
self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size')
|
||||
self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
|
||||
self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
|
||||
self.parser.add_argument('--label_nc', type=int, default=35, help='number of labels')
|
||||
self.parser.add_argument('--label_nc', type=int, default=0, help='number of labels')
|
||||
self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
|
||||
|
||||
# network arch
|
||||
@ -38,7 +38,7 @@ class BaseOptions():
|
||||
self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
|
||||
|
||||
self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
|
||||
self.parser.add_argument('--resize_or_crop', type=str, default='scaleWidth', help='scaling and cropping of images at load time [resize_and_crop|crop|scaledCrop|scaleWidth|scaleWidth_and_crop|scaleWidth_and_scaledCrop] etc')
|
||||
self.parser.add_argument('--resize_or_crop', type=str, default='scaleWidth', help='scaling and cropping of images at load time [resize_and_crop|crop|scaledCrop|scaleWidth|scaleWidth_and_crop|scaleWidth_and_scaledCrop|scaleHeight|scaleHeight_and_crop] etc')
|
||||
self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation')
|
||||
|
||||
# more features as input
|
||||
@ -61,6 +61,18 @@ class BaseOptions():
|
||||
self.parser.add_argument('--use_single_G', action='store_true', help='if specified, use single frame generator for the first frame')
|
||||
self.parser.add_argument('--fg', action='store_true', help='if specified, use foreground-background seperation model')
|
||||
self.parser.add_argument('--fg_labels', type=str, default='26', help='label indices for foreground objects')
|
||||
self.parser.add_argument('--no_flow', action='store_true', help='if specified, do not use flow warping and directly synthesize frames')
|
||||
|
||||
# face specific
|
||||
self.parser.add_argument('--no_canny_edge', action='store_true', help='do *not* use canny edge as input')
|
||||
self.parser.add_argument('--no_dist_map', action='store_true', help='do *not* use distance transform map as input')
|
||||
|
||||
# pose specific
|
||||
self.parser.add_argument('--densepose_only', action='store_true', help='use only densepose as input')
|
||||
self.parser.add_argument('--openpose_only', action='store_true', help='use only openpose as input')
|
||||
self.parser.add_argument('--add_face_disc', action='store_true', help='add face discriminator')
|
||||
self.parser.add_argument('--remove_face_labels', action='store_true', help='remove face labels to better adapt to different face shapes')
|
||||
self.parser.add_argument('--random_drop_prob', type=float, default=0.2, help='the probability to randomly drop each pose segment during training')
|
||||
|
||||
# miscellaneous
|
||||
self.parser.add_argument('--load_pretrain', type=str, default='', help='if specified, load the pretrained model')
|
||||
|
||||
@ -10,5 +10,6 @@ class TestOptions(BaseOptions):
|
||||
self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
|
||||
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
|
||||
self.parser.add_argument('--how_many', type=int, default=300, help='how many test images to run')
|
||||
self.parser.add_argument('--use_real_img', action='store_true', help='use real image for first frame')
|
||||
self.parser.add_argument('--use_real_img', action='store_true', help='use real image for first frame')
|
||||
self.parser.add_argument('--start_frame', type=int, default=0, help='frame index to start inference on')
|
||||
self.isTrain = False
|
||||
|
||||
@ -33,9 +33,10 @@ class TrainOptions(BaseOptions):
|
||||
self.parser.add_argument('--n_frames_D', type=int, default=3, help='number of frames to feed into temporal discriminator')
|
||||
self.parser.add_argument('--n_scales_temporal', type=int, default=3, help='number of temporal scales in the temporal discriminator')
|
||||
self.parser.add_argument('--max_frames_per_gpu', type=int, default=1, help='max number of frames to load into one GPU at a time')
|
||||
self.parser.add_argument('--max_frames_backpropagate', type=int, default=1, help='max number of frames to backpropagate')
|
||||
self.parser.add_argument('--max_frames_backpropagate', type=int, default=1, help='max number of frames to backpropagate')
|
||||
self.parser.add_argument('--max_t_step', type=int, default=1, help='max spacing between neighboring sampled frames. If greater than 1, the network may randomly skip frames during training.')
|
||||
self.parser.add_argument('--n_frames_total', type=int, default=30, help='the overall number of frames in a sequence to train with')
|
||||
self.parser.add_argument('--niter_step', type=int, default=5, help='how many epochs do we change training batch size again')
|
||||
self.parser.add_argument('--niter_fix_global', type=int, default=0, help='if specified, only train the finest spatial layer for the given iterations')
|
||||
self.parser.add_argument('--niter_fix_global', type=int, default=0, help='if specified, only train the finest spatial layer for the given iterations')
|
||||
|
||||
self.isTrain = True
|
||||
|
||||
10
scripts/download_datasets.py
Executable file
10
scripts/download_datasets.py
Executable file
@ -0,0 +1,10 @@
|
||||
import os
|
||||
from download_gdrive import *
|
||||
|
||||
file_id = '1rPcbnanuApZeo2uc7h55OneBkbcFCnnf'
|
||||
chpt_path = './datasets/'
|
||||
if not os.path.isdir(chpt_path):
|
||||
os.makedirs(chpt_path)
|
||||
destination = os.path.join(chpt_path, 'datasets.zip')
|
||||
download_file_from_google_drive(file_id, destination)
|
||||
unzip_file(destination, chpt_path)
|
||||
10
scripts/face/download_models.py
Executable file
10
scripts/face/download_models.py
Executable file
@ -0,0 +1,10 @@
|
||||
import os
|
||||
from scripts.download_gdrive import *
|
||||
|
||||
file_id = '10LvNw-2lrh-6sPGkWbQDfHspkqz5AKxb'
|
||||
chpt_path = './checkpoints/'
|
||||
if not os.path.isdir(chpt_path):
|
||||
os.makedirs(chpt_path)
|
||||
destination = os.path.join(chpt_path, 'models_face.zip')
|
||||
download_file_from_google_drive(file_id, destination)
|
||||
unzip_file(destination, chpt_path)
|
||||
3
scripts/face/test_512.sh
Executable file
3
scripts/face/test_512.sh
Executable file
@ -0,0 +1,3 @@
|
||||
python test.py --name edge2face_512 \
|
||||
--dataroot datasets/face/ --dataset_mode face \
|
||||
--input_nc 15 --loadSize 512 --use_single_G
|
||||
3
scripts/face/test_g1_256.sh
Executable file
3
scripts/face/test_g1_256.sh
Executable file
@ -0,0 +1,3 @@
|
||||
python test.py --name edge2face_256_g1 \
|
||||
--dataroot datasets/face/ --dataset_mode face \
|
||||
--input_nc 15 --loadSize 256 --ngf 64 --use_single_G
|
||||
4
scripts/face/test_g1_512.sh
Executable file
4
scripts/face/test_g1_512.sh
Executable file
@ -0,0 +1,4 @@
|
||||
python test.py --name edge2face_512_g1 \
|
||||
--dataroot datasets/face/ --dataset_mode face \
|
||||
--n_scales_spatial 2 --input_nc 15 --loadSize 512 --ngf 64 \
|
||||
--use_single_G
|
||||
5
scripts/face/train_512.sh
Executable file
5
scripts/face/train_512.sh
Executable file
@ -0,0 +1,5 @@
|
||||
python train.py --name edge2face_512 \
|
||||
--dataroot datasets/face/ --dataset_mode face \
|
||||
--input_nc 15 --loadSize 512 --num_D 3 \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 6 \
|
||||
--n_frames_total 12
|
||||
4
scripts/face/train_g1_256.sh
Executable file
4
scripts/face/train_g1_256.sh
Executable file
@ -0,0 +1,4 @@
|
||||
python train.py --name edge2face_256_g1 \
|
||||
--dataroot datasets/face/ --dataset_mode face \
|
||||
--input_nc 15 --loadSize 256 --ngf 64 \
|
||||
--max_frames_per_gpu 6 --n_frames_total 12
|
||||
7
scripts/face/train_g1_512.sh
Executable file
7
scripts/face/train_g1_512.sh
Executable file
@ -0,0 +1,7 @@
|
||||
python train.py --name edge2face_512_g1 \
|
||||
--dataroot datasets/face/ --dataset_mode face \
|
||||
--n_scales_spatial 2 --num_D 3 \
|
||||
--input_nc 15 --loadSize 512 --ngf 64 \
|
||||
--n_frames_total 6 --niter_step 2 --niter_fix_global 5 \
|
||||
--niter 5 --niter_decay 5 \
|
||||
--lr 0.0001 --load_pretrain checkpoints/edge2face_256_g1
|
||||
4
scripts/pose/test_1024p.sh
Executable file
4
scripts/pose/test_1024p.sh
Executable file
@ -0,0 +1,4 @@
|
||||
python test.py --name pose2body_1024p \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --n_scales_spatial 3 \
|
||||
--resize_or_crop scaleHeight --loadSize 1024 --no_first_img
|
||||
3
scripts/pose/test_256p.sh
Executable file
3
scripts/pose/test_256p.sh
Executable file
@ -0,0 +1,3 @@
|
||||
python test.py --name pose2body_256p \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --resize_or_crop scaleHeight --loadSize 256 --no_first_img
|
||||
4
scripts/pose/test_512p.sh
Executable file
4
scripts/pose/test_512p.sh
Executable file
@ -0,0 +1,4 @@
|
||||
python test.py --name pose2body_512p \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --n_scales_spatial 2 \
|
||||
--resize_or_crop scaleHeight --loadSize 512 --no_first_img
|
||||
4
scripts/pose/test_g1_1024p.sh
Executable file
4
scripts/pose/test_g1_1024p.sh
Executable file
@ -0,0 +1,4 @@
|
||||
python test.py --name pose2body_1024p_g1 \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --n_scales_spatial 3 --ngf 64 \
|
||||
--resize_or_crop scaleHeight --loadSize 1024 --no_first_img
|
||||
3
scripts/pose/test_g1_256p.sh
Executable file
3
scripts/pose/test_g1_256p.sh
Executable file
@ -0,0 +1,3 @@
|
||||
python test.py --name pose2body_256p_g1 \
|
||||
--dataroot datasets/pose --dataset_mode pose --ngf 64 \
|
||||
--input_nc 6 --resize_or_crop scaleHeight --loadSize 256 --no_first_img
|
||||
4
scripts/pose/test_g1_512p.sh
Executable file
4
scripts/pose/test_g1_512p.sh
Executable file
@ -0,0 +1,4 @@
|
||||
python test.py --name pose2body_512p_g1 \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --n_scales_spatial 2 --ngf 64 \
|
||||
--resize_or_crop scaleHeight --loadSize 512 --no_first_img
|
||||
8
scripts/pose/train_1024p.sh
Executable file
8
scripts/pose/train_1024p.sh
Executable file
@ -0,0 +1,8 @@
|
||||
python train.py --name pose2body_1024p \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --n_scales_spatial 3 --num_D 4 \
|
||||
--resize_or_crop randomScaleHeight_and_scaledCrop --loadSize 1536 --fineSize 1024 \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 4 \
|
||||
--no_first_img --n_frames_total 12 --max_t_step 4 --add_face_disc \
|
||||
--niter_fix_global 3 --niter 5 --niter_decay 5 \
|
||||
--lr 0.00005 --load_pretrain checkpoints/pose2body_512p
|
||||
6
scripts/pose/train_256p.sh
Executable file
6
scripts/pose/train_256p.sh
Executable file
@ -0,0 +1,6 @@
|
||||
python train.py --name pose2body_256p \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --num_D 2 \
|
||||
--resize_or_crop ScaleHeight_and_scaledCrop --loadSize 384 --fineSize 256 \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --batchSize 8 --max_frames_per_gpu 3 \
|
||||
--no_first_img --n_frames_total 12 --max_t_step 4
|
||||
8
scripts/pose/train_512p.sh
Executable file
8
scripts/pose/train_512p.sh
Executable file
@ -0,0 +1,8 @@
|
||||
python train.py --name pose2body_512p \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --n_scales_spatial 2 --num_D 3 \
|
||||
--resize_or_crop randomScaleHeight_and_scaledCrop --loadSize 768 --fineSize 512 \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --batchSize 8 \
|
||||
--no_first_img --n_frames_total 12 --max_t_step 4 --add_face_disc \
|
||||
--niter_fix_global 3 --niter 5 --niter_decay 5 \
|
||||
--lr 0.0001 --load_pretrain checkpoints/pose2body_256p
|
||||
7
scripts/pose/train_g1_1024p.sh
Executable file
7
scripts/pose/train_g1_1024p.sh
Executable file
@ -0,0 +1,7 @@
|
||||
python train.py --name pose2body_1024p_g1 \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --n_scales_spatial 3 --num_D 4 --ngf 64 --ndf 32 \
|
||||
--resize_or_crop randomScaleHeight_and_scaledCrop --loadSize 1536 --fineSize 1024 \
|
||||
--no_first_img --n_frames_total 12 --max_t_step 4 --add_face_disc \
|
||||
--niter_fix_global 3 --niter 5 --niter_decay 5 \
|
||||
--lr 0.00005 --load_pretrain checkpoints/pose2body_512p_g1
|
||||
5
scripts/pose/train_g1_256p.sh
Executable file
5
scripts/pose/train_g1_256p.sh
Executable file
@ -0,0 +1,5 @@
|
||||
python train.py --name pose2body_256p_g1 \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --ngf 64 --num_D 2 \
|
||||
--resize_or_crop randomScaleHeight_and_scaledCrop --loadSize 384 --fineSize 256 \
|
||||
--no_first_img --n_frames_total 12 --max_frames_per_gpu 4 --max_t_step 4
|
||||
7
scripts/pose/train_g1_512p.sh
Executable file
7
scripts/pose/train_g1_512p.sh
Executable file
@ -0,0 +1,7 @@
|
||||
python train.py --name pose2body_512p_g1 \
|
||||
--dataroot datasets/pose --dataset_mode pose \
|
||||
--input_nc 6 --n_scales_spatial 2 --ngf 64 --num_D 3 \
|
||||
--resize_or_crop randomScaleHeight_and_scaledCrop --loadSize 768 --fineSize 512 \
|
||||
--no_first_img --n_frames_total 12 --max_frames_per_gpu 2 --max_t_step 4 --add_face_disc \
|
||||
--niter_fix_global 3 --niter 5 --niter_decay 5 \
|
||||
--lr 0.0001 --load_pretrain checkpoints/pose2body_256p_g1
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from download_gdrive import *
|
||||
from scripts.download_gdrive import *
|
||||
|
||||
file_id = '1MKtImgtnGC28EPU7Nh9DfFpHW6okNVkl'
|
||||
chpt_path = './checkpoints/'
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from download_gdrive import *
|
||||
from scripts.download_gdrive import *
|
||||
|
||||
file_id = '1QoE1p3QikxNVbbTBWWRDtIspg-RcLE8y'
|
||||
chpt_path = './checkpoints/'
|
||||
@ -7,4 +7,4 @@ if not os.path.isdir(chpt_path):
|
||||
os.makedirs(chpt_path)
|
||||
destination = os.path.join(chpt_path, 'models_g1.zip')
|
||||
download_file_from_google_drive(file_id, destination)
|
||||
unzip_file(destination, chpt_path)
|
||||
unzip_file(destination, chpt_path)
|
||||
1
scripts/street/test_2048.sh
Executable file
1
scripts/street/test_2048.sh
Executable file
@ -0,0 +1 @@
|
||||
python test.py --name label2city_2048 --label_nc 35 --loadSize 2048 --n_scales_spatial 3 --use_instance --fg --use_single_G
|
||||
1
scripts/street/test_g1_1024.sh
Executable file
1
scripts/street/test_g1_1024.sh
Executable file
@ -0,0 +1 @@
|
||||
python test.py --name label2city_1024_g1 --label_nc 35 --loadSize 1024 --n_scales_spatial 3 --use_instance --fg --n_downsample_G 2 --use_single_G
|
||||
@ -1,5 +1,5 @@
|
||||
python train.py --name label2city_1024 \
|
||||
--loadSize 1024 --n_scales_spatial 2 --num_D 3 --use_instance --fg \
|
||||
--label_nc 35 --loadSize 1024 --n_scales_spatial 2 --num_D 3 --use_instance --fg \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 4 \
|
||||
--n_frames_total 4 --niter_step 2 \
|
||||
--niter_fix_global 10 --load_pretrain checkpoints/label2city_512 --lr 0.0001
|
||||
@ -1,5 +1,5 @@
|
||||
python train.py --name label2city_2048 \
|
||||
--loadSize 2048 --n_scales_spatial 3 --num_D 4 --use_instance --fg \
|
||||
--label_nc 35 --loadSize 2048 --n_scales_spatial 3 --num_D 4 --use_instance --fg \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 4 \
|
||||
--n_frames_total 4 --niter_step 1 \
|
||||
--niter 5 --niter_decay 5 \
|
||||
@ -1,5 +1,5 @@
|
||||
python train.py --name label2city_2048_crop \
|
||||
--loadSize 2048 --fineSize 1024 --resize_or_crop crop \
|
||||
--label_nc 35 --loadSize 2048 --fineSize 1024 --resize_or_crop crop \
|
||||
--n_scales_spatial 3 --num_D 4 --use_instance --fg \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 4 \
|
||||
--n_frames_total 4 --niter_step 1 \
|
||||
@ -1,4 +1,4 @@
|
||||
python train.py --name label2city_512 \
|
||||
--loadSize 512 --use_instance --fg \
|
||||
--label_nc 35 --loadSize 512 --use_instance --fg \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 6 \
|
||||
--n_frames_total 6 --max_frames_per_gpu 2
|
||||
@ -1,4 +1,4 @@
|
||||
python train.py --name label2city_512_bs \
|
||||
--loadSize 512 --use_instance --fg \
|
||||
--label_nc 35 --loadSize 512 --use_instance --fg \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 6 \
|
||||
--n_frames_total 6 --batchSize 6
|
||||
@ -1,4 +1,4 @@
|
||||
python train.py --name label2city_512_no_fg \
|
||||
--loadSize 512 --use_instance \
|
||||
--label_nc 35 --loadSize 512 --use_instance \
|
||||
--gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 6 \
|
||||
--n_frames_total 6 --max_frames_per_gpu 2
|
||||
@ -1,5 +1,5 @@
|
||||
python train.py --name label2city_1024_g1 \
|
||||
--loadSize 896 --n_scales_spatial 3 --n_frames_D 2 \
|
||||
--label_nc 35 --loadSize 896 --n_scales_spatial 3 --n_frames_D 2 \
|
||||
--use_instance --fg --n_downsample_G 2 --num_D 3 \
|
||||
--max_frames_per_gpu 1 --n_frames_total 4 \
|
||||
--niter_step 2 --niter_fix_global 8 --niter_decay 5 \
|
||||
@ -1,4 +1,4 @@
|
||||
python train.py --name label2city_256 \
|
||||
--loadSize 256 --use_instance --fg \
|
||||
--label_nc 35 --loadSize 256 --use_instance --fg \
|
||||
--n_downsample_G 2 --num_D 1 \
|
||||
--max_frames_per_gpu 6 --n_frames_total 6
|
||||
@ -1,5 +1,5 @@
|
||||
python train.py --name label2city_512_g1 \
|
||||
--loadSize 512 --n_scales_spatial 2 \
|
||||
--label_nc 35 --loadSize 512 --n_scales_spatial 2 \
|
||||
--use_instance --fg --n_downsample_G 2 \
|
||||
--max_frames_per_gpu 2 --n_frames_total 4 \
|
||||
--niter_step 2 --niter_fix_global 8 --niter_decay 5 \
|
||||
@ -1 +0,0 @@
|
||||
python test.py --name label2city_1024_g1 --dataroot datasets/Cityscapes/test_A --loadSize 1024 --n_scales_spatial 3 --use_instance --fg --n_downsample_G 2 --use_single_G
|
||||
@ -1 +0,0 @@
|
||||
python test.py --name label2city_2048 --dataroot datasets/Cityscapes/test_A --loadSize 2048 --n_scales_spatial 3 --use_instance --fg --use_single_G
|
||||
27
test.py
27
test.py
@ -17,7 +17,8 @@ opt.nThreads = 1 # test code only supports nThreads = 1
|
||||
opt.batchSize = 1 # test code only supports batchSize = 1
|
||||
opt.serial_batches = True # no shuffle
|
||||
opt.no_flip = True # no flip
|
||||
opt.dataset_mode = 'test'
|
||||
if opt.dataset_mode == 'temporal':
|
||||
opt.dataset_mode = 'test'
|
||||
|
||||
data_loader = CreateDataLoader(opt)
|
||||
dataset = data_loader.load_data()
|
||||
@ -25,10 +26,7 @@ model = create_model(opt)
|
||||
visualizer = Visualizer(opt)
|
||||
input_nc = 1 if opt.label_nc != 0 else opt.input_nc
|
||||
|
||||
# create website
|
||||
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
|
||||
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
|
||||
|
||||
save_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
|
||||
print('Doing %d frames' % len(dataset))
|
||||
for i, data in enumerate(dataset):
|
||||
if i >= opt.how_many:
|
||||
@ -38,18 +36,19 @@ for i, data in enumerate(dataset):
|
||||
|
||||
_, _, height, width = data['A'].size()
|
||||
A = Variable(data['A']).view(1, -1, input_nc, height, width)
|
||||
B = Variable(data['B']).view(1, -1, opt.output_nc, height, width) if opt.use_real_img else None
|
||||
inst = Variable(data['inst']).view(1, -1, 1, height, width) if opt.use_instance else None
|
||||
B = Variable(data['B']).view(1, -1, opt.output_nc, height, width) if len(data['B'].size()) > 2 else None
|
||||
inst = Variable(data['inst']).view(1, -1, 1, height, width) if len(data['inst'].size()) > 2 else None
|
||||
generated = model.inference(A, B, inst)
|
||||
|
||||
if opt.label_nc != 0:
|
||||
real_A = util.tensor2label(generated[1][0], opt.label_nc)
|
||||
else:
|
||||
real_A = util.tensor2im(generated[1][0,0:1], normalize=False)
|
||||
|
||||
real_A = util.tensor2label(generated[1], opt.label_nc)
|
||||
else:
|
||||
c = 3 if opt.input_nc == 3 else 1
|
||||
real_A = util.tensor2im(generated[1][:c], normalize=False)
|
||||
|
||||
visual_list = [('real_A', real_A),
|
||||
('fake_B', util.tensor2im(generated[0].data[0]))]
|
||||
('fake_B', util.tensor2im(generated[0].data[0]))]
|
||||
visuals = OrderedDict(visual_list)
|
||||
img_path = data['A_paths']
|
||||
img_path = data['A_path']
|
||||
print('process image... %s' % img_path)
|
||||
visualizer.save_images(webpage, visuals, img_path)
|
||||
visualizer.save_images(save_dir, visuals, img_path)
|
||||
56
train.py
56
train.py
@ -7,6 +7,8 @@ import torch
|
||||
from torch.autograd import Variable
|
||||
from collections import OrderedDict
|
||||
from subprocess import call
|
||||
import fractions
|
||||
def lcm(a,b): return abs(a * b)/fractions.gcd(a,b) if a and b else 0
|
||||
|
||||
from options.train_options import TrainOptions
|
||||
from data.data_loader import CreateDataLoader
|
||||
@ -25,7 +27,10 @@ def train():
|
||||
data_loader = CreateDataLoader(opt)
|
||||
dataset = data_loader.load_data()
|
||||
dataset_size = len(data_loader)
|
||||
print('#training videos = %d' % dataset_size)
|
||||
if opt.dataset_mode == 'pose':
|
||||
print('#training frames = %d' % dataset_size)
|
||||
else:
|
||||
print('#training videos = %d' % dataset_size)
|
||||
|
||||
### initialize models
|
||||
modelG, modelD, flowNet = create_model(opt)
|
||||
@ -50,9 +55,8 @@ def train():
|
||||
else:
|
||||
start_epoch, epoch_iter = 1, 0
|
||||
|
||||
### set parameters
|
||||
bs = opt.batchSize
|
||||
n_gpus = opt.n_gpus_gen // bs # number of gpus used for generator for each batch
|
||||
### set parameters
|
||||
n_gpus = opt.n_gpus_gen // opt.batchSize # number of gpus used for generator for each batch
|
||||
tG, tD = opt.n_frames_G, opt.n_frames_D
|
||||
tDB = tD * opt.output_nc
|
||||
s_scales = opt.n_scales_spatial
|
||||
@ -60,6 +64,7 @@ def train():
|
||||
input_nc = 1 if opt.label_nc != 0 else opt.input_nc
|
||||
output_nc = opt.output_nc
|
||||
|
||||
opt.print_freq = lcm(opt.print_freq, opt.batchSize)
|
||||
total_steps = (start_epoch-1) * dataset_size + epoch_iter
|
||||
total_steps = total_steps // opt.print_freq * opt.print_freq
|
||||
|
||||
@ -89,8 +94,8 @@ def train():
|
||||
for i in range(0, n_frames_total-t_len+1, n_frames_load):
|
||||
# 5D tensor: batchSize, # of frames, # of channels, height, width
|
||||
input_A = Variable(data['A'][:, i*input_nc:(i+t_len)*input_nc, ...]).view(-1, t_len, input_nc, height, width)
|
||||
input_B = Variable(data['B'][:, i*output_nc:(i+t_len)*output_nc, ...]).view(-1, t_len, output_nc, height, width)
|
||||
inst_A = Variable(data['inst'][:, i:i+t_len, ...]).view(-1, t_len, 1, height, width) if opt.use_instance else None
|
||||
input_B = Variable(data['B'][:, i*output_nc:(i+t_len)*output_nc, ...]).view(-1, t_len, output_nc, height, width)
|
||||
inst_A = Variable(data['inst'][:, i:i+t_len, ...]).view(-1, t_len, 1, height, width) if len(data['inst'].size()) > 2 else None
|
||||
|
||||
###################################### Forward Pass ##########################
|
||||
####### generator
|
||||
@ -131,6 +136,9 @@ def train():
|
||||
loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
|
||||
loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_Feat'] + loss_dict['G_VGG']
|
||||
loss_G += loss_dict['G_Warp'] + loss_dict['F_Flow'] + loss_dict['F_Warp'] + loss_dict['W']
|
||||
if opt.add_face_disc:
|
||||
loss_G += loss_dict['G_f_GAN'] + loss_dict['G_f_GAN_Feat']
|
||||
loss_D += (loss_dict['D_f_fake'] + loss_dict['D_f_real']) * 0.5
|
||||
|
||||
# collect temporal losses
|
||||
loss_D_T = []
|
||||
@ -165,7 +173,7 @@ def train():
|
||||
############## Display results and errors ##########
|
||||
### print out errors
|
||||
if total_steps % opt.print_freq == 0:
|
||||
t = (time.time() - iter_start_time) / opt.print_freq / opt.batchSize
|
||||
t = (time.time() - iter_start_time) / opt.print_freq
|
||||
errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
|
||||
for s in range(len(loss_dict_T)):
|
||||
errors.update({k+str(s): v.data.item() if not isinstance(v, int) else v for k, v in loss_dict_T[s].items()})
|
||||
@ -173,24 +181,36 @@ def train():
|
||||
visualizer.plot_current_errors(errors, total_steps)
|
||||
|
||||
### display output images
|
||||
if save_fake:
|
||||
if save_fake:
|
||||
if opt.label_nc != 0:
|
||||
input_image = util.tensor2label(real_A[0, -1], opt.label_nc)
|
||||
elif opt.dataset_mode == 'pose':
|
||||
input_image = util.tensor2im(real_A[0, -1, :3], normalize=False)
|
||||
if real_A.size()[2] == 6:
|
||||
input_image2 = util.tensor2im(real_A[0, -1, 3:], normalize=False)
|
||||
input_image[input_image2 != 0] = input_image2[input_image2 != 0]
|
||||
else:
|
||||
input_image = util.tensor2im(real_A[0, -1, :3], normalize=False)
|
||||
c = 3 if opt.input_nc == 3 else 1
|
||||
input_image = util.tensor2im(real_A[0, -1, :c], normalize=False)
|
||||
if opt.use_instance:
|
||||
edges = util.tensor2im(real_A[0, -1, -1:,...], normalize=False)
|
||||
input_image += edges[:,:,np.newaxis]
|
||||
edges = util.tensor2im(real_A[0, -1, -1:,...], normalize=False)
|
||||
input_image += edges[:,:,np.newaxis]
|
||||
|
||||
if opt.add_face_disc:
|
||||
ys, ye, xs, xe = modelD.module.get_face_region(real_A[0, -1:])
|
||||
if ys is not None:
|
||||
input_image[ys, xs:xe, :] = input_image[ye, xs:xe, :] = input_image[ys:ye, xs, :] = input_image[ys:ye, xe, :] = 255
|
||||
|
||||
visual_list = [('input_image', input_image),
|
||||
('fake_image', util.tensor2im(fake_B[0, -1])),
|
||||
('fake_first_image', util.tensor2im(fake_B_first)),
|
||||
('fake_raw_image', util.tensor2im(fake_B_raw[0, -1])),
|
||||
('real_image', util.tensor2im(real_B[0, -1])),
|
||||
('flow', util.tensor2flow(flow[0, -1])),
|
||||
('real_image', util.tensor2im(real_B[0, -1])),
|
||||
('flow_ref', util.tensor2flow(flow_ref[0, -1])),
|
||||
('conf_ref', util.tensor2im(conf_ref[0, -1], normalize=False)),
|
||||
('weight', util.tensor2im(weight[0, -1], normalize=False))]
|
||||
('conf_ref', util.tensor2im(conf_ref[0, -1], normalize=False))]
|
||||
if flow is not None:
|
||||
visual_list += [('flow', util.tensor2flow(flow[0, -1])),
|
||||
('weight', util.tensor2im(weight[0, -1], normalize=False))]
|
||||
visuals = OrderedDict(visual_list)
|
||||
visualizer.display_current_results(visuals, epoch, total_steps)
|
||||
|
||||
@ -227,7 +247,7 @@ def train():
|
||||
### gradually grow training sequence length
|
||||
if (epoch % opt.niter_step) == 0:
|
||||
data_loader.dataset.update_training_batch(epoch//opt.niter_step)
|
||||
modelG.module.update_training_batch(epoch//opt.niter_step)
|
||||
modelG.module.update_training_batch(epoch//opt.niter_step)
|
||||
|
||||
### finetune all scales
|
||||
if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
|
||||
@ -236,8 +256,10 @@ def train():
|
||||
def reshape(tensors):
|
||||
if isinstance(tensors, list):
|
||||
return [reshape(tensor) for tensor in tensors]
|
||||
if tensors is None:
|
||||
return None
|
||||
_, _, ch, h, w = tensors.size()
|
||||
return tensors.view(-1, ch, h, w)
|
||||
return tensors.contiguous().view(-1, ch, h, w)
|
||||
|
||||
# get temporally subsampled frames for real/fake sequences
|
||||
def get_skipped_frames(B_all, B, t_scales, tD):
|
||||
|
||||
91
util/util.py
91
util/util.py
@ -65,63 +65,10 @@ def tensor2flow(output, imtype=np.uint8):
|
||||
rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
|
||||
return rgb
|
||||
|
||||
def make_anaglyph(imL, imR):
|
||||
lRed, lGreen, lBlue = imL[:,:,0], imL[:,:,1], imL[:,:,2]
|
||||
rRed, rGreen, rBlue = imR[:,:,0], imR[:,:,1], imR[:,:,2]
|
||||
return np.dstack((rRed, lGreen, lBlue))
|
||||
|
||||
def ycbcr2rgb(img_y, img_cb, img_cr):
|
||||
im = np.dstack((img_y, img_cb, img_cr))
|
||||
xform = np.array([[1, 0, 1.402], [1, -0.34414, -.71414], [1, 1.772, 0]])
|
||||
rgb = im.astype(np.float)
|
||||
rgb[:,:,[1,2]] -= 128
|
||||
return np.uint8(np.clip(rgb.dot(xform.T), 0, 255))
|
||||
|
||||
def rgb2yuv(R, G, B):
|
||||
Y = 0.299*R + 0.587*G + 0.114*B
|
||||
U = -0.147*R - 0.289*G + 0.436*B
|
||||
V = 0.615*R - 0.515*G - 0.100*B
|
||||
return Y, U, V
|
||||
|
||||
def yuv2rgb(Y, U, V):
|
||||
R = (Y + 1.14 * V)
|
||||
G = (Y - 0.39 * U - 0.58 * V)
|
||||
B = (Y + 2.03 * U)
|
||||
return R, G, B
|
||||
|
||||
def diagnose_network(net, name='network'):
|
||||
mean = 0.0
|
||||
count = 0
|
||||
for param in net.parameters():
|
||||
if param.grad is not None:
|
||||
mean += torch.mean(torch.abs(param.grad.data))
|
||||
count += 1
|
||||
if count > 0:
|
||||
mean = mean / count
|
||||
print(name)
|
||||
print(mean)
|
||||
|
||||
|
||||
def save_image(image_numpy, image_path):
|
||||
image_pil = Image.fromarray(image_numpy)
|
||||
image_pil.save(image_path)
|
||||
|
||||
def info(object, spacing=10, collapse=1):
|
||||
"""Print methods and doc strings.
|
||||
Takes module, class, list, dictionary, or string."""
|
||||
methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)]
|
||||
processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
|
||||
print( "\n".join(["%s %s" %
|
||||
(method.ljust(spacing),
|
||||
processFunc(str(getattr(object, method).__doc__)))
|
||||
for method in methodList]) )
|
||||
|
||||
def varname(p):
|
||||
for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
|
||||
m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
|
||||
if m:
|
||||
return m.group(1)
|
||||
|
||||
def print_numpy(x, val=True, shp=False):
|
||||
x = x.astype(np.float64)
|
||||
if shp:
|
||||
@ -131,7 +78,6 @@ def print_numpy(x, val=True, shp=False):
|
||||
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
||||
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
||||
|
||||
|
||||
def mkdirs(paths):
|
||||
if isinstance(paths, list) and not isinstance(paths, str):
|
||||
for path in paths:
|
||||
@ -139,7 +85,6 @@ def mkdirs(paths):
|
||||
else:
|
||||
mkdir(paths)
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
@ -149,43 +94,22 @@ def uint82bin(n, count=8):
|
||||
return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
|
||||
|
||||
def labelcolormap(N):
|
||||
if N == 35: # GTA/cityscape train
|
||||
if N == 35: # Cityscapes train
|
||||
cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
|
||||
(128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
|
||||
(180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
|
||||
(107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
|
||||
( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
|
||||
dtype=np.uint8)
|
||||
elif N == 20: # GTA/cityscape eval
|
||||
elif N == 20: # Cityscapes eval
|
||||
cmap = np.array([(128, 64,128), (244, 35,232), ( 70, 70, 70), (102,102,156), (190,153,153), (153,153,153), (250,170, 30),
|
||||
(220,220, 0), (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142),
|
||||
( 0, 0, 70), ( 0, 60,100), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0, 0)],
|
||||
dtype=np.uint8)
|
||||
elif N == 23: # Synthia
|
||||
cmap = np.array([(0, 0, 0 ), (70, 130,180), (70, 70, 70 ), (128,64, 128), (244,35, 232), (64, 64, 128), (107,142,35 ),
|
||||
(153,153,153), (0, 0, 142), (220,220,0 ), (220,20, 60 ), (119,11, 32 ), (0, 0, 230), (250,170,160),
|
||||
(128,64, 64 ), (250,170,30 ), (152,251,152), (255,0, 0 ), (0, 0, 70 ), (0, 60, 100), (0, 80, 100),
|
||||
(102,102,156), (102,102,156)],
|
||||
dtype=np.uint8)
|
||||
elif N == 32: # new GTA train
|
||||
cmap = np.array([(0, 0, 0), (111, 74, 0), (70, 130, 180), (128, 64, 128), (244, 35, 232), (230, 150, 140), (152, 251, 152),
|
||||
(87, 182, 35), (35, 142, 35), (70, 70, 70), (153, 153, 153), (190, 153, 153), (150, 20, 20), (250, 170, 30),
|
||||
(220, 220, 0), (180, 180, 100), (173, 153, 153), (168, 153, 153), (81, 0, 21), (81, 0, 81), (220, 20, 60),
|
||||
(255, 0, 0), (119, 11, 32), (0, 0, 230), (0, 0, 142), (0, 80, 100), (0, 60, 100), (0, 0 , 70),
|
||||
(0, 0, 90), (0, 80, 100), (0, 100, 100), (50, 0, 90)],
|
||||
dtype=np.uint8)
|
||||
elif N == 24: # new GTA eval
|
||||
cmap = np.array([(70, 130, 180), (128, 64, 128), (244, 35, 232), (152, 251, 152), (87, 182, 35), (35, 142, 35), (70, 70, 70),
|
||||
(153, 153, 153), (190, 153, 153), (150, 20, 20), (250, 170, 30), (220, 220, 0), (180, 180, 100), (173, 153, 153),
|
||||
(168, 153, 153), (81, 0, 21), (81, 0, 81), (220, 20, 60), (0, 0, 230), (0, 0, 142), (0, 80, 100),
|
||||
(0, 60, 100), (0, 0 , 70), (0, 0, 0)],
|
||||
dtype=np.uint8)
|
||||
elif N == 154 or N == 11 or N == 151 or N == 233:
|
||||
else:
|
||||
cmap = np.zeros((N, 3), dtype=np.uint8)
|
||||
for i in range(N):
|
||||
r = 0
|
||||
g = 0
|
||||
b = 0
|
||||
r, g, b = 0, 0, 0
|
||||
id = i
|
||||
for j in range(7):
|
||||
str_id = uint82bin(id)
|
||||
@ -193,16 +117,11 @@ def labelcolormap(N):
|
||||
g = g ^ (np.uint8(str_id[-2]) << (7-j))
|
||||
b = b ^ (np.uint8(str_id[-3]) << (7-j))
|
||||
id = id >> 3
|
||||
cmap[i, 0] = r
|
||||
cmap[i, 1] = g
|
||||
cmap[i, 2] = b
|
||||
else:
|
||||
raise NotImplementedError('Colorization for label number [%s] is not recognized' % N)
|
||||
cmap[i, 0], cmap[i, 1], cmap[i, 2] = r, g, b
|
||||
return cmap
|
||||
|
||||
def colormap(n):
|
||||
cmap = np.zeros([n, 3]).astype(np.uint8)
|
||||
|
||||
for i in np.arange(n):
|
||||
r, g, b = np.zeros(3)
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
import numpy as np
|
||||
import os
|
||||
import ntpath
|
||||
import time
|
||||
from . import util
|
||||
from . import html
|
||||
@ -112,15 +111,16 @@ class Visualizer():
|
||||
log_file.write('%s\n' % message)
|
||||
|
||||
# save image to the disk
|
||||
def save_images(self, webpage, visuals, image_path):
|
||||
image_dir = webpage.get_image_dir()
|
||||
short_path = ntpath.basename(image_path[0])
|
||||
name = os.path.splitext(short_path)[0]
|
||||
def save_images(self, image_dir, visuals, image_path, webpage=None):
|
||||
dirname = os.path.basename(os.path.dirname(image_path[0]))
|
||||
image_dir = os.path.join(image_dir, dirname)
|
||||
util.mkdir(image_dir)
|
||||
name = os.path.basename(image_path[0])
|
||||
name = os.path.splitext(name)[0]
|
||||
|
||||
webpage.add_header(name)
|
||||
ims = []
|
||||
txts = []
|
||||
links = []
|
||||
if webpage is not None:
|
||||
webpage.add_header(name)
|
||||
ims, txts, links = [], [], []
|
||||
|
||||
for label, image_numpy in visuals.items():
|
||||
save_ext = 'png' if 'real_A' in label and self.opt.label_nc != 0 else 'jpg'
|
||||
@ -128,10 +128,12 @@ class Visualizer():
|
||||
save_path = os.path.join(image_dir, image_name)
|
||||
util.save_image(image_numpy, save_path)
|
||||
|
||||
ims.append(image_name)
|
||||
txts.append(label)
|
||||
links.append(image_name)
|
||||
webpage.add_images(ims, txts, links, width=self.win_size)
|
||||
if webpage is not None:
|
||||
ims.append(image_name)
|
||||
txts.append(label)
|
||||
links.append(image_name)
|
||||
if webpage is not None:
|
||||
webpage.add_images(ims, txts, links, width=self.win_size)
|
||||
|
||||
def vis_print(self, message):
|
||||
print(message)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user