CVPR 2023 (✨Highlight)
Linyi Jin1, Jianming Zhang2, Yannick Hold-Geoffroy2, Oliver Wang2, Kevin Matzen2, Matthew Sticha1, David Fouhey1
1University of Michigan, 2Adobe Research

We propose Perspective Fields as a representation that models the local perspective properties of an image. Perspective Fields contain per-pixel information about the camera view, parameterized as an up vector and a latitude value.
📷 From Perspective Fields, you can also recover camera parameters if you assume a particular camera model. We provide models to recover camera roll, pitch, field of view (FoV), and principal point location.
This branch is for training and evaluation and uses the Detectron2 framework. If you want to run inference with minimal package dependencies, please check out the main branch.
- Note
- Environment Setup
- Training
- Testing
- Model Zoo
- Coordinate Frame
- Inference
- Camera Parameters to Perspective Fields
- Visualize Perspective Fields
- Citation
- Acknowledgment
PerspectiveFields requires python >= 3.8 and PyTorch.
> Pro tip: use mamba in place of conda for much faster installs.

The dependencies can be installed by running:
```bash
git clone git@github.com:jinlinyi/PerspectiveFields.git
# create virtual env
conda create -n perspective python=3.9
conda activate perspective
# install a pytorch build compatible with your system: https://pytorch.org/get-started/previous-versions/
# conda install pytorch torchvision cudatoolkit -c pytorch
conda install pytorch=1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pytorch
# conda packages
conda install -c conda-forge openexr-python openexr
# pip packages
pip install -r requirements.txt
# install mmcv with mim (plain `pip install mmcv` caused issues)
mim install mmcv
# install Perspective Fields
pip install -e .
```
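As an optional sanity check (the module names below are assumed from the imports used elsewhere in this README), you can confirm the environment from Python:

```python
# Optional: verify that the core dependencies and the editable install are importable.
# Module names assumed from this README (torch, detectron2, perspective2d).
import torch
import detectron2
import perspective2d

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```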
We used the Google Street View (GSV) dataset to train the ParamNet reported in Tables 3 and 4 of our paper. Download the GSV datasets:
```bash
wget https://www.dropbox.com/s/plcmcza8vfmmpkm/google_street_view_191210.tar
wget https://www.dropbox.com/s/9se3lrpljd59cod/gsv_test_crop_uniform.tar
```

Extract the dataset under `perspectiveField/datasets`.
Note that we used images from 360cities to train the PerspectiveNet in Table 1.
Download the initial backbone weights, segformer.b3.512x512.ade.160k.pth:
```bash
wget https://www.dropbox.com/s/0axxpfga265gq3o/ade_pretrained.pth
```

Place it under `perspectiveField/init_model_weights`.
- We first trained PerspectiveNet:
```bash
python -W ignore train.py \
    --config-file configs/config-mix-gsv-regress.yaml \
    --num-gpus 2 \
    --dist-url tcp://127.0.0.1:$((RANDOM + 10000)) \
    OUTPUT_DIR "./exp/step01-gsv-perspective-pretrain" \
    SOLVER.IMS_PER_BATCH 64
```

- Then we trained the ParamNet. You can download the model from the previous step here:
```bash
wget https://www.dropbox.com/s/c9199n5lmy30tob/gsv_persnet_pretrain.pth
```

- To train the ParamNet to predict roll, pitch, and FoV:
```bash
python -W ignore train.py \
    --config-file configs/config-gsv-rpf.yaml \
    --num-gpus 2 \
    --dist-url tcp://127.0.0.1:$((RANDOM + 10000)) \
    OUTPUT_DIR "./exp/step02-gsv-paramnet-rpf"
```

- To train the ParamNet to predict roll, pitch, FoV, and principal point:
```bash
python -W ignore train.py \
    --config-file configs/config-gsv-rpfpp.yaml \
    --num-gpus 2 \
    --dist-url tcp://127.0.0.1:$((RANDOM + 10000)) \
    OUTPUT_DIR "./exp/step02-gsv-paramnet-rpfpp"
```
In our paper, we tested the PersNet-360Cities model on images from the publicly available Stanford2D3D and TartanAir datasets. Results can be found in Table 1.
To download the Stanford2D3D dataset, first agree to their data sharing and usage terms: link.
https://www.dropbox.com/sh/ycd4hv0t1nqagub/AACjqZ2emGw7L-aAJ1rmpX4-a

Download the TartanAir dataset:

https://www.dropbox.com/sh/7tev8uqnnjfhzhb/AAD9y_d1DCcoZ-AQDEQ1tn0Ua

Extract the datasets under `perspectiveField/datasets`.
We also tested the PersNet_paramnet-GSV-centered and PersNet_Paramnet-GSV-uncentered models on centered and uncentered images from Google Street View (GSV). Results can be found in Tables 3 and 4.
Download GSV datasets:
```bash
wget https://www.dropbox.com/s/plcmcza8vfmmpkm/google_street_view_191210.tar
wget https://www.dropbox.com/s/9se3lrpljd59cod/gsv_test_crop_uniform.tar
```

Extract the datasets under `perspectiveField/datasets`.
- PerspectiveNet:
First, to test PerspectiveNet, provide a dataset name corresponding to a name/path pair from perspective2d/data/datasets/builtin.py. Create and provide an output folder under perspectiveField/exps. Choose a model and provide the path to the config file and weights, both of which should be under perspectiveField/models.
Example:
```bash
python -W ignore demo/test_persfield.py \
    --dataset stanford2d3d_test \
    --config-file ./models/cvpr2023.yaml \
    --opts MODEL.WEIGHTS ./models/cvpr2023.pth
```

- ParamNet:
To test ParamNet, again provide a dataset name, output folder, and a path to config and model weights, just as with PerspectiveNet.
Example:
```bash
python -W ignore demo/test_param_network.py \
    --dataset gsv_test \
    --output ./exps/paramnet_gsv_test \
    --config-file ./models/paramnet_gsv_rpf.yaml \
    --opts MODEL.WEIGHTS ./models/paramnet_gsv_rpf.pth
```

```bash
python -W ignore demo/test_param_network.py \
    --dataset gsv_test_crop_uniform \
    --output ./exps/paramnet_gsv_test \
    --config-file ./models/paramnet_gsv_rpfpp.yaml \
    --opts MODEL.WEIGHTS ./models/paramnet_gsv_rpfpp.pth
```

NOTE: Extract model weights under `perspectiveField/models`.
| Model Name and Weights | Training Dataset | Config File | Outputs | Expected input |
|---|---|---|---|---|
| [NEW] Paramnet-360Cities-edina-centered | 360cities and EDINA | paramnet_360cities_edina_rpf.yaml | Perspective Field + camera parameters (roll, pitch, vfov) | Uncropped, indoor🏠, outdoor🏙️, natural🌳, and egocentric👋 data |
| [NEW] Paramnet-360Cities-edina-uncentered | 360cities and EDINA | paramnet_360cities_edina_rpfpp.yaml | Perspective Field + camera parameters (roll, pitch, vfov, cx, cy) | Cropped, indoor🏠, outdoor🏙️, natural🌳, and egocentric👋 data |
| PersNet-360Cities | 360cities | cvpr2023.yaml | Perspective Field | Indoor🏠, outdoor🏙️, and natural🌳 data. |
| PersNet_paramnet-GSV-centered | GSV | paramnet_gsv_rpf.yaml | Perspective Field + camera parameters (roll, pitch, vfov) | Uncropped, street view🏙️ data. |
| PersNet_Paramnet-GSV-uncentered | GSV | paramnet_gsv_rpfpp.yaml | Perspective Field + camera parameters (roll, pitch, vfov, cx, cy) | Cropped, street view🏙️ data. |
- yaw / azimuth: camera rotation about the y-axis
- pitch / elevation: camera rotation about the x-axis
- roll: camera rotation about the z-axis
- Extrinsics: `rotz(roll).dot(rotx(elevation)).dot(roty(azimuth))` (see the sketch below)
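For reference, here is a minimal NumPy sketch of that composition. It is illustrative only: the angles are assumed to be in radians and the matrices follow standard right-handed conventions, while the exact sign conventions used internally are defined by the library.

```python
import numpy as np

def rotx(t):
    # rotation about the x-axis (pitch / elevation)
    c, s = np.cos(t), np.sin(t)
    return np.array([[1, 0, 0],
                     [0, c, -s],
                     [0, s, c]])

def roty(t):
    # rotation about the y-axis (yaw / azimuth)
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, 0, s],
                     [0, 1, 0],
                     [-s, 0, c]])

def rotz(t):
    # rotation about the z-axis (roll)
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0],
                     [s, c, 0],
                     [0, 0, 1]])

# Extrinsic rotation composed as described above (example angles).
roll, elevation, azimuth = np.radians([5.0, 20.0, 30.0])
R = rotz(roll) @ rotx(elevation) @ roty(azimuth)
```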
- Live Demo 🤗.
- We also provide a notebook to Predict Perspective Fields and Recover Camera Parameters.
- Alternatively, you can run demo.py:
```bash
python demo/demo.py \
    --config-file <config-path> \     # e.g. ../jupyter-notebooks/models/cvpr2023.yaml
    --input <input-path> \            # e.g. ../assets/imgs
    --output <output-path> \          # e.g. debug
    --opts MODEL.WEIGHTS <ckpt-path>  # e.g. ../jupyter-notebooks/models/cvpr2023.pth
```

Check out the Jupyter Notebook. Perspective Fields can be calculated from camera parameters. If you prefer, you can also manually calculate the corresponding Up-vector and Latitude map by following Equations 1 and 2 in our paper. Our code currently supports:
- Pinhole model [Hartley and Zisserman 2004] (Perspective Projection)
```python
import numpy as np

from perspective2d.utils.panocam import PanoCam

# define parameters
roll = 0
pitch = 20
vfov = 70
width = 640
height = 480

# get Up-vectors
up = PanoCam.get_up(np.radians(vfov), width, height, np.radians(pitch), np.radians(roll))
# get Latitude
lati = PanoCam.get_lat(np.radians(vfov), width, height, np.radians(pitch), np.radians(roll))
```

- Unified Spherical Model [Barreto 2006; Mei and Rives 2007] (Distortion)
```python
xi = 0.5  # distortion parameter from the Unified Spherical Model
x = -np.sin(np.radians(vfov / 2))
z = np.sqrt(1 - x**2)
f_px_effective = -0.5 * (width / 2) * (xi + z) / x
crop, _, _, _, up, lat, xy_map = PanoCam.crop_distortion(
    equi_img,          # equirectangular (360°) source image
    f=f_px_effective,
    xi=xi,
    H=height,
    W=width,
    az=yaw,            # degrees
    el=-pitch,
    roll=-roll,
)
```

We provide a one-line call to blend Perspective Fields onto the input image:
```python
import matplotlib.pyplot as plt

from perspective2d.utils import draw_perspective_fields

# Draw up and lati on img. lati is in radians.
blend = draw_perspective_fields(img, up, lati)
# visualize with matplotlib
plt.imshow(blend)
plt.show()
```

Perspective Fields can serve as an easy visual check for correctness of the camera parameters.
- For example, we can visualize the Perspective Fields based on calibration results from this awesome repo.
- Left: we plot the Perspective Fields based on the numbers printed on the image; they look accurate 😊.
- Mid: if we try a number that is 10% off (0.72 * 0.9 = 0.648), we see a mismatch in the Up directions at the top-right corner.
- Right: if the distortion is 20% off (0.72 * 0.8 = 0.576), the mismatch becomes more obvious.
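If you want to script this kind of check yourself, here is a minimal sketch using the pinhole helpers shown above. The calibration numbers are placeholders, and the image path is hypothetical; substitute the parameters you want to verify.

```python
import numpy as np
import matplotlib.pyplot as plt

from perspective2d.utils import draw_perspective_fields
from perspective2d.utils.panocam import PanoCam

# Candidate calibration to check (placeholder values, in degrees).
roll, pitch, vfov = 2.0, 15.0, 72.0

# Load the image whose calibration you want to verify (placeholder path).
img = plt.imread("path/to/your_image.jpg")
height, width = img.shape[:2]

# Perspective Fields implied by the candidate parameters (pinhole model).
up = PanoCam.get_up(np.radians(vfov), width, height, np.radians(pitch), np.radians(roll))
lati = PanoCam.get_lat(np.radians(vfov), width, height, np.radians(pitch), np.radians(roll))

# Overlay: misaligned up vectors or latitude contours suggest the parameters are off.
plt.imshow(draw_perspective_fields(img, up, lati))
plt.axis("off")
plt.show()
```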
If you find this code useful, please consider citing:
```bibtex
@inproceedings{jin2023perspective,
  title     = {Perspective Fields for Single Image Camera Calibration},
  author    = {Linyi Jin and Jianming Zhang and Yannick Hold-Geoffroy and Oliver Wang and Kevin Matzen and Matthew Sticha and David F. Fouhey},
  booktitle = {CVPR},
  year      = {2023}
}
```
This work was partially funded by the DARPA Machine Common Sense Program. We thank the authors of A Deep Perceptual Measure for Lens and Camera Calibration for releasing their code for the Unified Spherical Model.