
DeepWave-KAUST/sweep


SWEEP

Seismic Wave Equation Exploration Platform (SWEEP) is a Python package designed for seismic wave equation modeling and inversion.

Note: Starting from this version, lazy imports are supported. You no longer need to install both JAX and PyTorch; install only the backend you intend to use.

Installation

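From the repository root, build a wheel and install it:

```bash
pip install build        # provides the `python -m build` command
python -m build
pip install dist/*.whl
```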
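Because only one backend is required, it can help to check which one is available in the current environment before running the examples. The standard library can locate a package without importing it. The installed_backends helper below is a hypothetical sketch and is not part of SWEEP.

```python
import importlib.util


def installed_backends(candidates=("torch", "jax")):
    """Hypothetical helper: return the backends that are importable, without importing them."""
    return [name for name in candidates if importlib.util.find_spec(name) is not None]


if __name__ == "__main__":
    backends = installed_backends()
    if backends:
        print("Available backends:", ", ".join(backends))
    else:
        print("No backend found: install PyTorch or JAX first.")
```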

Usage

The following example uses the PyTorch backend to compute the gradient of a toy loss (the sum of squared recorded samples) with respect to the velocity model. The equivalent JAX calls are shown as comments.

```python
import torch
# import jax                  # JAX backend
# import jax.numpy as jnp     # JAX backend
# Let cuDNN pick the fastest kernels for fixed-size inputs
torch.backends.cudnn.benchmark = True
from sweep.propagator.torch import PropTorch
# from sweep.propagator.jax import PropJax   # JAX backend
from sweep.equations import Acoustic
from sweep.signal import ricker
import numpy as np
import matplotlib.pyplot as plt

# Model parameters
nt = 1500           # number of time samples (the wavelet below uses nt // 2)
dt = 0.002          # time step (s)
dh = 10             # grid spacing (m)
delay = 0.1         # wavelet delay (s)
fm = 5              # dominant frequency (Hz)
spatial_order = 8   # finite-difference order in space
shape = (100, 100)  # grid size

# Device
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Create a two-layer model: 1500 m/s in rows 0-49, 2000 m/s in rows 50-99
true_model = np.ones(shape, dtype=np.float32) * 1500
true_model[50:, :] = 2000

# Create the propagator
model = PropTorch(Acoustic(spatial_order=spatial_order, device=dev, backend='torch'),
                  shape=shape,
                  dev=dev,
                  dh=dh,
                  dt=dt,
                  source_type=['h1'],
                  receiver_type=['h1'],
                  pml_type='cpmlr',
                  free_surface=False)

# Set the model parameters (PyTorch): a leaf tensor that accumulates gradients
vp = torch.from_numpy(true_model).to(dev).requires_grad_(True)
# Set the model parameters (JAX; build the model with PropJax instead of PropTorch)
# model.set_parameters([jnp.array(true_model)])

# Create a Ricker wavelet with nt // 2 samples
t = np.arange(nt // 2) * dt
wave = ricker(t - delay, f=fm)

# Acquisition geometry (grid indices)
sources = np.array([[1, 1]])        # shape = (nshots, 2)
receivers = np.array([[[99, 1]]])   # shape = (nshots, nreceivers, 2)

# Forward modeling (PyTorch)
obs = model.forward(wave, sources, receivers, models=[vp])
# Backward propagation (PyTorch): the toy loss is the sum of squared samples
obs.pow(2).sum().backward()

# Forward modeling and backward propagation (JAX)
# def fwi(vp):
#     return (model(wave, sources, receivers, models=[vp])**2).sum()
# grad = jax.grad(fwi)(model.vp)

# Show the results
fig, axes = plt.subplots(1, 3, figsize=(12, 3))

axes[0].imshow(true_model, cmap='seismic', aspect='auto')
axes[0].set_title('True model')

axes[1].plot(obs.detach().cpu().numpy().squeeze())
axes[1].set_title('Observed data')

grad = vp.grad.detach().cpu().numpy()  # PyTorch
# grad = jax.device_get(grad)          # JAX
vmin, vmax = np.percentile(grad, [1, 99])  # clip outliers for display
axes[2].imshow(grad, cmap='seismic', aspect='auto', vmin=vmin, vmax=vmax)
axes[2].set_title('Gradient of vp')

fig.tight_layout()
fig.savefig('grad_vp.png', dpi=300, bbox_inches='tight')
plt.close(fig)
```
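The script above uses one shot and one receiver. For several shots, the input arrays only need to follow the shapes noted in the comments: a wavelet sampled every dt, sources of shape (nshots, 2) and receivers of shape (nshots, nreceivers, 2). The sketch below builds such arrays with NumPy. ricker_sketch is the textbook Ricker formula, w(t) = (1 - 2π²f²t²) exp(-π²f²t²), used here as a stand-in for sweep.signal.ricker, whose exact normalization may differ. line_geometry is a hypothetical helper. Which grid axis each column refers to is an assumption; check it against the examples folder.

```python
import numpy as np


def ricker_sketch(t, f):
    """Textbook Ricker wavelet, peak amplitude 1 at t = 0.

    Hypothetical stand-in: sweep.signal.ricker may use a different normalization.
    """
    arg = (np.pi * f * t) ** 2
    return (1.0 - 2.0 * arg) * np.exp(-arg)


def line_geometry(nshots, n_along, fixed_index=1, rec_step=1):
    """Hypothetical helper: shots and receivers spread along one grid axis.

    Column 0 varies and column 1 is fixed, as in the README example, where the
    source [1, 1] and the receiver [99, 1] differ only in their first index.
    Returns sources with shape (nshots, 2) and receivers with shape
    (nshots, nreceivers, 2), all in grid indices.
    """
    shot_pos = np.linspace(1, n_along - 2, nshots).round().astype(int)
    sources = np.stack([shot_pos, np.full(nshots, fixed_index)], axis=1)
    rec_pos = np.arange(1, n_along - 1, rec_step)
    one_spread = np.stack([rec_pos, np.full(rec_pos.size, fixed_index)], axis=1)
    # Every shot records with the same fixed spread of receivers.
    receivers = np.repeat(one_spread[None, :, :], nshots, axis=0)
    return sources, receivers


# An integer sample count avoids the floating-point length ambiguity of np.arange(0, T, dt).
nt, dt, delay, fm = 1500, 0.002, 0.1, 5
t = np.arange(nt // 2) * dt
wave = ricker_sketch(t - delay, f=fm)
sources, receivers = line_geometry(nshots=4, n_along=100)
```

With these parameters the wavelet peaks at sample 50 (t = delay), and every shot shares the same 98-receiver spread.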

The true model, the observed data and the gradient with respect to the velocity model are shown below.

(Figure grad_vp: true model, observed data and gradient of vp.)
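The gradient in the figure, vp.grad, holds the derivative of the toy loss (the sum of squared recorded samples) with respect to each velocity cell. SWEEP obtains it by automatic differentiation through the propagator. As an independent illustration of what that quantity means, the sketch below estimates the same kind of derivative by brute-force central differences on a tiny 1-D finite-difference model. Everything here (forward_1d, the zero-pressure ends, adding the source sample as a pressure increment) is a hypothetical simplification, not SWEEP's scheme. It also needs two simulations per cell, which is why autodiff or adjoint methods are used in practice.

```python
import numpy as np


def forward_1d(vp, wave, src, rec, dh, dt):
    """Toy 1-D acoustic modeling: 2nd-order finite differences in space and time.

    Zero-pressure ends and no absorbing layer, unlike SWEEP's PML options.
    Returns the pressure recorded at grid index `rec` at every time step.
    """
    nx = vp.size
    prev = np.zeros(nx)
    curr = np.zeros(nx)
    trace = np.zeros(wave.size)
    c2 = (vp * dt / dh) ** 2  # squared Courant number in every cell (must stay <= 1)
    for it in range(wave.size):
        lap = np.zeros(nx)
        lap[1:-1] = curr[2:] - 2.0 * curr[1:-1] + curr[:-2]
        nxt = 2.0 * curr - prev + c2 * lap
        nxt[src] += wave[it]  # inject the source sample as a pressure increment
        prev, curr = curr, nxt
        trace[it] = curr[rec]
    return trace


def toy_loss(vp, wave, src, rec, dh, dt):
    """Same toy objective as the README example: sum of squared recorded samples."""
    return np.sum(forward_1d(vp, wave, src, rec, dh, dt) ** 2)


def fd_gradient(vp, wave, src, rec, dh, dt, eps=1e-2):
    """Central-difference estimate of d(loss)/d(vp[i]); two simulations per cell."""
    grad = np.zeros_like(vp)
    for i in range(vp.size):
        up, down = vp.copy(), vp.copy()
        up[i] += eps
        down[i] -= eps
        grad[i] = (toy_loss(up, wave, src, rec, dh, dt)
                   - toy_loss(down, wave, src, rec, dh, dt)) / (2.0 * eps)
    return grad


if __name__ == "__main__":
    # Two-layer 1-D model reusing the README parameters (dh = 10 m, dt = 2 ms, fm = 5 Hz).
    dh, dt, fm, delay = 10.0, 0.002, 5.0, 0.1
    vp = np.full(60, 1500.0)
    vp[30:] = 2000.0
    t = np.arange(300) * dt
    arg = (np.pi * fm * (t - delay)) ** 2
    wave = (1.0 - 2.0 * arg) * np.exp(-arg)  # Ricker wavelet
    print(fd_gradient(vp, wave, src=1, rec=58, dh=dh, dt=dt))
```

The two end cells always get a zero gradient here, because the fixed zero-pressure boundary makes their velocity irrelevant to the recorded trace.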

Examples

More examples are provided in the examples folder. This folder is still being updated, because some of the APIs have changed.

License

This project is licensed under the MIT License; see the LICENSE file for details.

About

Scalable GPU seismic imaging toolkit for wave simulation, RTM, and FWI.
