Example: Optimization and expand

In this example, we optimize a tiny octree so that the color it renders along a single fixed ray matches the target RGB vector [0, 1, 0.5].

We start with data format SH1 and, halfway through, switch to SH4 using the expand(format) method, which automatically inserts the extra channels. In both phases we optimize with manual gradient descent on the squared error between the rendered and target colors (summed over channels rather than averaged, so strictly it is not an MSE). The rendered color slowly approaches the target.
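To make "inserts extra channels" concrete, here is a pure-PyTorch sketch of what such an expansion could look like for a single leaf. It is not the svox implementation: the helper `expand_sh_channels` is hypothetical, and it assumes that an SH format with k basis functions stores `[R_0..R_{k-1}, G_0..G_{k-1}, B_0..B_{k-1}, sigma]` per leaf. Under that assumption, the existing coefficients are copied into the first slot of each color block and the new higher-order slots start at zero, so the rendered color is unchanged at the moment of expansion.

```python
import torch


def expand_sh_channels(data: torch.Tensor, old_basis: int, new_basis: int) -> torch.Tensor:
    """Hypothetical SH channel expansion (an illustration, not svox's code).

    Assumed per-leaf layout: [R_0..R_{k-1}, G_0..G_{k-1}, B_0..B_{k-1}, sigma],
    where k is the number of SH basis functions.
    """
    assert data.shape[-1] == 3 * old_basis + 1
    assert new_basis >= old_basis
    # New channels are allocated as zeros, so higher-order terms start inactive
    out = data.new_zeros((*data.shape[:-1], 3 * new_basis + 1))
    for c in range(3):  # copy the R, G and B blocks one at a time
        out[..., c * new_basis : c * new_basis + old_basis] = \
            data[..., c * old_basis : (c + 1) * old_basis]
    out[..., -1] = data[..., -1]  # density is carried over unchanged
    return out


# One leaf in the assumed SH1 layout: one coefficient per color, then density
leaf = torch.tensor([[0.2, -0.4, 0.1, 0.5]])
expanded = expand_sh_channels(leaf, old_basis=1, new_basis=4)
print(expanded.shape)
print(expanded)
```

Because only zeros are added, the first rendered color after expansion should continue smoothly from the last one before it, which is what the printed output shows.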

import svox
import torch

device = 'cuda:0'

# A tiny tree storing one SH coefficient per color channel, plus density
t = svox.N3Tree(device=device, data_format="SH1")

# Leaf at (0, 0, 0): zero the color coefficients, set density (last channel) to 0.5
t[0, 0, 0, :-1] = 0.0
t[0, 0, 0, -1:] = 0.5
r = svox.VolumeRenderer(t)

# Target RGB color for the rendered ray
target = torch.tensor([[0.0, 1.0, 0.5]], device=device)

# A single fixed ray travelling along +z
ray_ori = torch.tensor([[0.1, 0.1, -0.1]], device=device)
ray_dir = torch.tensor([[0.0, 0.0, 1.0]], device=device)
ray = svox.Rays(origins=ray_ori, dirs=ray_dir, viewdirs=ray_dir)

lr = 1e2


def run_gradient_descent(steps=20):
    for i in range(steps):
        rend = r(ray, cuda=True)
        if i % 2 == 0:
            print(rend.detach()[0].cpu().numpy())
        # Squared error between rendered and target colors, summed over RGB
        ((rend - target) ** 2).sum().backward()
        # Plain gradient-descent step applied directly to the tree's leaf data
        t.data.data -= lr * t.data.grad
        t.zero_grad()


print('GRADIENT DESC')
run_gradient_descent()

print('Expanding..')
# Switch to 4 SH basis functions per channel; extra channels are inserted
t.expand("SH4")
print(r.data_format)
run_gradient_descent()

print('TARGET')
print(target[0].cpu().numpy())
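The same manual update pattern (forward pass, backward on the squared error, in-place step, reset gradients) can be tried on a CPU without svox. The sketch below is a hypothetical stand-in for the renderer, not svox's: it assumes a single leaf of path length 0.5 whose color is `sigmoid(coeffs)`, composited over a white background with transmittance `exp(-relu(sigma) * length)`. It also uses a smaller learning rate than the example, since this toy model has no SH scaling. It illustrates why the red channel approaches 0 only slowly: any remaining transmittance lets the white background leak into every channel.

```python
import torch

# Hypothetical one-leaf stand-in for the volume renderer (an assumption, not svox)
coeffs = torch.zeros(3, requires_grad=True)       # color logits for R, G, B
sigma = torch.tensor(0.5, requires_grad=True)     # density of the leaf
target = torch.tensor([0.0, 1.0, 0.5])
length, background, lr = 0.5, 1.0, 0.5


def render() -> torch.Tensor:
    trans = torch.exp(-torch.relu(sigma) * length)  # transmittance through the leaf
    return (1 - trans) * torch.sigmoid(coeffs) + trans * background


losses = []
for step in range(300):
    loss = ((render() - target) ** 2).sum()       # summed squared error, as above
    losses.append(loss.item())
    loss.backward()
    with torch.no_grad():                         # manual gradient-descent step
        coeffs -= lr * coeffs.grad
        sigma -= lr * sigma.grad
    coeffs.grad = None                            # reset gradients for the next step
    sigma.grad = None

print(losses[0], losses[-1])
print(render().detach())
```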

The output:

GRADIENT DESC
[0.88920575 0.88920575 0.88920575]
[0.67369866 0.6859846  0.67984015]
[0.6194525  0.65586865 0.63762873]
[0.58019906 0.64437073 0.61214054]
[0.5475207  0.6409838  0.59386927]
[0.5188446 0.6420485 0.579674 ]
[0.49309036 0.64582    0.5681365 ]
[0.46970066 0.6513118  0.55849427]
[0.4483344 0.657904  0.5502867]
[0.42875046 0.66518104 0.54321104]
Expanding..
SH4
[0.4107593 0.6728529 0.5370555]
[0.3631369  0.71049845 0.5277597 ]
[0.32639033 0.7405325  0.52003586]
[0.29751268 0.7646569  0.51378375]
[0.27432522 0.7842779  0.5088086 ]
[0.25531954 0.80046684 0.50490075]
[0.23945224 0.8140159  0.50186735]
[0.22599061 0.8255081  0.4995423 ]
[0.21440998 0.83537465 0.4977861 ]
[0.2043267  0.84393847 0.49648416]
TARGET
[0.  1.  0.5]
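The first printed color can be checked by hand under a few assumptions not stated above: the default tree covers the unit cube split into eight leaves of side 0.5, all other leaves have zero density, colors are obtained by passing the SH result through a sigmoid (so zero coefficients give 0.5), and the background brightness is 1.0. The ray then crosses the leaf at (0, 0, 0) over a path length of $\Delta = 0.5$ with density $\sigma = 0.5$:

```latex
T = e^{-\sigma \Delta} = e^{-0.5 \times 0.5} \approx 0.7788,
\qquad
C = (1 - T)\,\operatorname{sigmoid}(0) + T \cdot 1
  \approx 0.2212 \times 0.5 + 0.7788 \approx 0.8894
```

This agrees with the printed 0.8892 to about three decimal places; the small gap presumably comes from the renderer's discrete step size.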