1D adv-diff: LSPG with nonlinear manifold projection via MLP

Overview

This demo solves the same problem as the one here, but instead of using POD modes, it shows how to use a nonlinear manifold approximated by a neural network. Specifically, we use an MLP autoencoder whose encoder and decoder each have 2 hidden layers, of sizes 200 and 64 in the encoder and 64 and 200 in the decoder.
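For reference, the difference from the POD version can be written compactly. The notation below ($\mathbf{x}$ for the FOM state, $\hat{\mathbf{x}}$ for the ROM state, $\boldsymbol{\Phi}$ for the POD basis, $\mathbf{g}$ for the decoder, $\mathbf{f}$ for the FOM velocity) is introduced here for illustration and is not taken from the pressio4py documentation. It sketches the standard LSPG setting with a BDF1 residual, matching the `ode.stepscheme.BDF1` scheme used later in the demo.

```latex
% linear (POD) vs nonlinear (decoder) approximation of the FOM state
\mathbf{x}(t) \approx \mathbf{x}_{\mathrm{ref}} + \boldsymbol{\Phi}\,\hat{\mathbf{x}}(t)
\qquad \text{vs.} \qquad
\mathbf{x}(t) \approx \mathbf{x}_{\mathrm{ref}} + \mathbf{g}\big(\hat{\mathbf{x}}(t)\big)

% LSPG with BDF1: at each step n, minimize the discrete FOM residual over the ROM state
\hat{\mathbf{x}}^{n} = \arg\min_{\hat{\mathbf{y}}} \big\lVert \mathbf{r}^{n}(\hat{\mathbf{y}}) \big\rVert_2^2,
\qquad
\mathbf{r}^{n}(\hat{\mathbf{y}}) = \mathbf{g}(\hat{\mathbf{y}}) - \mathbf{g}(\hat{\mathbf{x}}^{n-1})
  - \Delta t\, \mathbf{f}\big(\mathbf{x}_{\mathrm{ref}} + \mathbf{g}(\hat{\mathbf{y}}),\, t^{n}\big)

% Gauss-Newton needs the residual Jacobian, which contains the decoder Jacobian J_g
\frac{\partial \mathbf{r}^{n}}{\partial \hat{\mathbf{y}}}
  = \Big(\mathbf{I} - \Delta t\, \frac{\partial \mathbf{f}}{\partial \mathbf{x}}\Big)\, \mathbf{J}_{\mathbf{g}}(\hat{\mathbf{y}})
```

This is why the mapper must provide a Jacobian: with POD, $\mathbf{J}_{\mathbf{g}} = \boldsymbol{\Phi}$ is constant, whereas for the MLP decoder it changes with the ROM state and has to be updated.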

Main function

The main function of the demo is the following:

from scipy import linalg
from pressio4py import logger

# AdvDiff1d and doFom are not shown here (see the POD version of this demo);
# trainMapping and runLspg are defined in the sections below.

logger.initialize(logger.logto.terminal)
logger.setVerbosity([logger.loglevel.info])

# create fom object
fomObj = AdvDiff1d(nGrid=120, adv_coef=2.0)

# the final time to integrate to
finalTime = 0.05

#--- 1. FOM ---#
fomTimeStepSize  = 1e-5
fomNumberOfSteps = int(finalTime/fomTimeStepSize)
sampleEvery      = 200
[fomFinalState, snapshots] = doFom(fomObj, fomTimeStepSize, fomNumberOfSteps, sampleEvery)

#--- 2. train a nonlinear mapping using PyTorch ---#
# romSize is the latent (ROM) dimension; here it is 3, change it to try other sizes
myNonLinearMapper = trainMapping(snapshots, romSize=3, epochs=500)

#--- 3. LSPG ROM ---#
romTimeStepSize  = 3e-4
# note: finalTime/romTimeStepSize is not an integer, so int() truncates to
# 166 steps and the ROM stops at t = 0.0498 rather than exactly at finalTime
romNumberOfSteps = int(finalTime/romTimeStepSize)
approximatedState = runLspg(fomObj, romTimeStepSize, romNumberOfSteps, myNonLinearMapper)

# compute the relative l2-error between the fom and the approximated state
fomNorm = linalg.norm(fomFinalState)
err = linalg.norm(fomFinalState-approximatedState)
print("Final state relative l2 error: {}".format(err/fomNorm))

logger.finalize()
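Both step counts above are computed as `int(finalTime/dt)`, so the simulated time equals `finalTime` only when `dt` divides it exactly; for the ROM, `0.05/3e-4` is about 166.67. The helper `num_steps` below is hypothetical (not part of pressio4py or of this demo) and is only a sketch of one way to catch such a mismatch before running.

```python
import math

def num_steps(final_time, dt, rel_tol=1e-9):
    """Hypothetical helper: number of steps of size dt needed to reach final_time.

    Raises ValueError if dt does not divide final_time (up to rel_tol).
    """
    ratio = final_time / dt
    n = round(ratio)
    if not math.isclose(ratio, n, rel_tol=rel_tol):
        raise ValueError(
            f"dt={dt} does not divide final_time={final_time}: "
            f"{ratio:.4f} steps, so int() would stop at t={int(ratio) * dt:.6g}")
    return n

print(num_steps(0.05, 1e-5))   # FOM step size: prints 5000
try:
    num_steps(0.05, 3e-4)      # ROM step size used in the demo
except ValueError as err:
    print(err)
```

Using `round` plus an explicit check, instead of `int`, also protects against a ratio such as 4999.9999999 being truncated to one step too few.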

1. Run FOM and collect snapshots

This step is the same as described here.

2. Set up and train the nonlinear mapper

It is important to note that while the mapper class below has the API required by pressio4py, it can encapsulate any arbitrary mapping function. In this case we show how to create an MLP-based representation in PyTorch, but one can use any other type of mapping and any other library (e.g., TensorFlow, Keras). All of the PyTorch-specific code is encapsulated here. If you prefer TensorFlow/Keras, an equivalent implementation is here.
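To illustrate that the mapper only needs to expose a few methods, here is a minimal sketch of a purely linear (POD-like) mapper that uses the same method names as `MyMapper` further down (`jacobian`, `applyMapping`, `applyInverseMapping`, `updateJacobian`). The class `LinearMapper` and the random orthonormal basis are hypothetical and only for illustration; this sketch has not been run through pressio4py itself.

```python
import numpy as np

class LinearMapper:
    """Hypothetical linear mapper: fomState = Phi @ romState."""

    def __init__(self, phi):
        # column-major storage, mirroring how MyMapper stores its jacobian
        self.phi_ = np.asfortranarray(phi)
        self.numModes_ = phi.shape[1]

    def jacobian(self):
        # for a linear map the jacobian is the basis itself
        return self.phi_

    def applyMapping(self, romState, fomState):
        # write into the caller's array in place
        fomState[:] = self.phi_ @ romState

    def applyInverseMapping(self, fomState):
        # orthogonal projection, valid because Phi has orthonormal columns
        return self.phi_.T @ fomState

    def updateJacobian(self, romState):
        # nothing to do: the jacobian of a linear map does not depend on romState
        pass

# usage: 120 grid points, 3 modes, orthonormal columns from a QR factorization
rng = np.random.default_rng(0)
phi, _ = np.linalg.qr(rng.standard_normal((120, 3)))
mapper = LinearMapper(phi)
rom = np.array([1.0, -2.0, 0.5])
fom = np.zeros(120)
mapper.applyMapping(rom, fom)
print(np.allclose(mapper.applyInverseMapping(fom), rom))  # True
```

Swapping this class for the MLP-based one changes only how the mapping, its inverse and its Jacobian are computed; the LSPG code that consumes the mapper stays the same.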

The autoencoder is defined by

import torch
import torch.nn.functional as F

class myAutoencoder(torch.nn.Module):
  def __init__(self, fomSize, romSize=10):
    super(myAutoencoder, self).__init__()
    self.encoder = myEncoder(fomSize, romSize)
    self.decoder = myDecoder(fomSize, romSize)

  def forward(self, x):
    # encode to the latent (ROM) space, then decode back to the FOM space
    code = self.encoder(x)
    x = self.decoder(code)
    return x, code

  # note: this overrides torch.nn.Module.train(mode), so calling model.eval()
  # (which internally calls self.train(False)) does not work on this class
  def train(self, dataloader, optimizer, n_epochs, loss=torch.nn.MSELoss()):
    # multiply the learning rate by 0.8 whenever the epoch loss plateaus
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.8, min_lr=1e-6)
    for epoch in range(n_epochs):
      total_train_loss = 0.0
      for data,label in dataloader:
        optimizer.zero_grad()
        output, latent = self.forward(data)
        loss_size = loss(output, label)
        loss_size.backward()
        optimizer.step()
        # accumulate the loss of every batch, not only of the last one
        total_train_loss += loss_size.item()
      scheduler.step(total_train_loss)

class myEncoder(torch.nn.Module):
  def __init__(self, fomSize, romSize):
    super(myEncoder, self).__init__()
    self.fc1 = torch.nn.Linear(fomSize, 200)
    self.fc2 = torch.nn.Linear(200, 64)
    self.fc3 = torch.nn.Linear(64, romSize)

  def forward(self, x):
    x = self.fc1(x)
    x = F.elu(x)
    x = self.fc2(x)
    x = F.elu(x)
    x = self.fc3(x)
    # ELU is also applied to the latent code, so each latent coordinate is > -1
    x = F.elu(x)
    return x

class myDecoder(torch.nn.Module):
  def __init__(self, fomSize, romSize):
    super(myDecoder, self).__init__()
    # sizes are stored so that the mapper can query them later
    self.romSize_ = romSize
    self.fomSize_ = fomSize
    self.fc1 = torch.nn.Linear(romSize, 64)
    self.fc2 = torch.nn.Linear(64, 200)
    self.fc3 = torch.nn.Linear(200, fomSize)

  def forward(self, x):
    x = self.fc1(x)
    x = F.elu(x)
    x = self.fc2(x)
    x = F.elu(x)
    # no activation on the output layer: the FOM state is unbounded
    x = self.fc3(x)
    return x
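As a quick sanity check on model size, the short helper below counts the trainable parameters (weights plus biases of each fully connected layer) of the encoder and decoder defined above. It is plain Python written for illustration, not part of the demo; the layer widths are those of `myEncoder` and `myDecoder`, with `nGrid=120` and `romSize=3` from the main function.

```python
def linear_params(sizes):
    """Weights + biases of a chain of fully connected layers, e.g. [120, 200, 64, 3]."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

fom_size, rom_size = 120, 3   # nGrid=120 and romSize=3, as in the main function
encoder = linear_params([fom_size, 200, 64, rom_size])
decoder = linear_params([rom_size, 64, 200, fom_size])
print(encoder, decoder, encoder + decoder)  # 37259 37376 74635
```

In PyTorch the same total can be obtained directly with `sum(p.numel() for p in model.parameters())`.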

and is created/trained using

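import pathlib
import torch
import torch.optim as optim

def trainMapping(snapshots, romSize, epochs, enable_restart=False):
  # snapshots has shape (fomSize, numSnapshots), one snapshot per column
  fomSize = snapshots.shape[0]
  model = myAutoencoder(fomSize, romSize)
  optimizer = optim.AdamW(model.parameters(), lr=5e-3)

  # optionally resume from a previous training run
  if enable_restart:
    if pathlib.Path('TrainingCheckpoint.tar').is_file():
      print("Loading checkpoint")
      checkpoint = torch.load('TrainingCheckpoint.tar')
      model.load_state_dict(checkpoint['model_state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

  # each row of snapshots.T is one sample; input and target are identical
  # because the autoencoder is trained to reproduce its input
  samples = torch.utils.data.TensorDataset(torch.Tensor(snapshots.T), torch.Tensor(snapshots.T))
  loader = torch.utils.data.DataLoader(samples, batch_size=500, shuffle=True)
  model.train(loader, optimizer, n_epochs=epochs)

  # optionally save the trained state so that a later run can resume from it
  if enable_restart:
    torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
               },
               'TrainingCheckpoint.tar')

  # the mapper needs the decoder (ROM -> FOM) and the encoder (FOM -> ROM)
  return MyMapper(model.decoder, model.encoder)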
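The training loop lowers the AdamW learning rate (initially `5e-3`) with `ReduceLROnPlateau(factor=0.8, min_lr=1e-6)`, using the summed epoch loss as the monitored metric. The pure-Python function below is a simplified sketch of that rule in `min` mode, assuming PyTorch's defaults `patience=10` and a relative `threshold=1e-4`; it ignores `cooldown` and `eps` and is meant only to show when the rate drops, not to reproduce PyTorch exactly.

```python
def plateau_schedule(losses, lr=5e-3, factor=0.8, min_lr=1e-6,
                     patience=10, threshold=1e-4):
    """Simplified ReduceLROnPlateau ('min' mode): learning rate after each epoch."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best * (1.0 - threshold):   # relative improvement over the best loss
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:             # plateau: shrink the rate, but not below min_lr
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        history.append(lr)
    return history

lrs = plateau_schedule([1.0] * 15)
print(lrs[10], lrs[11])
```

With a constant loss, the first epoch sets the best value and the rate drops only at epoch index 11, after 11 epochs without improvement; with a steadily decreasing loss it never changes.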

This is all wrapped in a mapper class that conforms to the API required by Pressio:

import numpy as np
import torch

class MyMapper:
  def __init__(self, decoderObj, encoderObj):
    self.decoder_ = decoderObj
    self.encoder_ = encoderObj

    self.numModes_ = decoderObj.romSize_
    fomSize = decoderObj.fomSize_
    # work arrays used by the finite-difference jacobian
    self.fomState0 = np.zeros(fomSize)
    self.fomState1 = np.zeros(fomSize)
    # attention: the jacobian of the mapping must be in column-major order
    # so that pressio can view it without deep copying it; this makes it
    # possible to keep a single jacobian object and to update it in place
    # via the update methods below
    self.jacobian_ = np.zeros((fomSize,self.numModes_), order='F')

  def jacobian(self): return self.jacobian_

  def applyMapping(self, romState, fomState):
    # decode the ROM state and write the result into fomState in place
    fomState[:] = self.decoder_(torch.Tensor(romState)).detach().numpy()

  def applyInverseMapping(self, fomState):
    # encode a FOM state to obtain the corresponding ROM state
    romState = np.zeros(self.numModes_)
    romState[:] = self.encoder_(torch.Tensor(fomState)).detach().numpy()
    return romState

  def updateJacobian(self, romState):
    # finite differences are used by default; updateJacobianExact is an alternative
    self.updateJacobianFD(romState)

  def updateJacobianFD(self, romState):
    # forward finite differences to approximate the jacobian of the mapping
    romStateLocal = romState.copy()
    self.applyMapping(romStateLocal,self.fomState0)
    eps = 0.001
    for i in range(self.numModes_):
        romStateLocal[i] += eps
        self.applyMapping(romStateLocal, self.fomState1)
        self.jacobian_[:,i] = (self.fomState1 - self.fomState0) / eps
        romStateLocal[i] -= eps

  def updateJacobianExact(self, romState):
    # use PyTorch automatic differentiation to compute the jacobian of the mapping
    # (currently slower than finite differences)
    J = torch.autograd.functional.jacobian(self.decoder_, torch.Tensor(romState))
    self.jacobian_[:,:] = J.detach().numpy()
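`updateJacobianFD` approximates each Jacobian column with a forward difference, $(\mathbf{g}(\hat{\mathbf{x}} + \epsilon\,\mathbf{e}_i) - \mathbf{g}(\hat{\mathbf{x}}))/\epsilon$, whose truncation error is of order $\epsilon$. The NumPy sketch below runs the same loop on a hypothetical smooth map `toy_decoder` (not the trained network) whose Jacobian is known analytically, so the approximation can be checked. It also shows why `order='F'` is convenient: each Jacobian column is contiguous in memory. Unlike the original loop, it restores the perturbed entry by assignment rather than by subtracting `eps`, which avoids round-off drift.

```python
import numpy as np

def toy_decoder(rom_state):
    """Hypothetical smooth map R^2 -> R^3 standing in for the MLP decoder."""
    a, b = rom_state
    return np.array([np.sin(a), a * b, np.exp(b)])

def toy_jacobian(rom_state):
    """Analytic jacobian of toy_decoder, shape (3, 2)."""
    a, b = rom_state
    return np.array([[np.cos(a), 0.0],
                     [b,         a],
                     [0.0,       np.exp(b)]])

def fd_jacobian(decoder, rom_state, fom_size, eps=1e-3):
    """Forward-difference jacobian with the same loop structure as updateJacobianFD."""
    num_modes = rom_state.size
    jac = np.zeros((fom_size, num_modes), order='F')  # column-major, as in MyMapper
    local = rom_state.astype(float)                   # work on a copy
    base = decoder(local)
    for i in range(num_modes):
        local[i] += eps
        jac[:, i] = (decoder(local) - base) / eps
        local[i] = rom_state[i]                       # restore the entry exactly
    return jac

x_hat = np.array([0.3, -0.7])
J_fd = fd_jacobian(toy_decoder, x_hat, fom_size=3)
err = np.max(np.abs(J_fd - toy_jacobian(x_hat)))
print(err)                                 # well below 1e-3 for eps=1e-3
print(J_fd[:, 0].flags['C_CONTIGUOUS'])    # True: a column of an 'F'-order array
```

Note that the MLP decoder evaluates in single precision (`torch.Tensor` defaults to float32), which limits how small `eps` can usefully be made.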

3. Construct and run LSPG

import numpy as np
from scipy import linalg
from pressio4py import rom, ode, solvers

def runLspg(fomObj, dt, nsteps, customMapper):
  # auxiliary observer class that can be passed to the time integrator
  # to monitor the rom state during the time integration (here it does nothing)
  class RomStateObserver:
    def __call__(self, timeStep, time, state): pass

  # this linear solver is used at each Gauss-Newton iteration:
  # it solves A x = b with an LU factorization (LAPACK dgetrf + dgetrs)
  class MyLinSolver:
    def solve(self, A,b,x):
      lumat, piv, info = linalg.lapack.dgetrf(A, overwrite_a=True)
      x[:], info = linalg.lapack.dgetrs(lumat, piv, b, 0, 0)

  #----------------------------------------
  # create a custom decoder using the mapper passed as argument
  customDecoder = rom.Decoder(customMapper, "MyMapper")

  # fom reference state: here it is zero
  fomReferenceState = np.zeros(fomObj.nGrid)

  # create the initial ROM state by encoding the fom initial condition
  fomInitialState = fomObj.u0.copy()
  romState = customMapper.applyInverseMapping(fomInitialState)

  # create the LSPG problem (BDF1 = implicit Euler time stepping)
  scheme = ode.stepscheme.BDF1
  problem = rom.lspg.unsteady.DefaultProblem(scheme, fomObj, customDecoder, romState, fomReferenceState)

  # create the Gauss-Newton solver
  nonLinSolver = solvers.create_gauss_newton(problem, romState, MyLinSolver())
  # set tolerance and convergence criteria
  # note: nlsTol is not passed to the solver in this snippet, so the
  # solver's default tolerance applies to the stopping criterion below
  nlsTol, nlsMaxIt = 1e-7, 10
  nonLinSolver.setMaxIterations(nlsMaxIt)
  nonLinSolver.setStoppingCriterion(solvers.stop.WhenCorrectionAbsoluteNormBelowTolerance)

  # create the observer that monitors the romState during time integration
  myObs = RomStateObserver()
  # solve the problem
  ode.advance_n_steps_and_observe(problem, romState, 0., dt, nsteps, myObs, nonLinSolver)

  # after we are done, get the reconstructor object, which maps a
  # romState back to a fom state, and reconstruct the final fom state
  fomRecon = problem.fomStateReconstructor()
  return fomRecon(romState)
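Each LSPG time step above is a nonlinear least-squares problem solved by the Gauss-Newton solver, using the mapper's Jacobian and a linear solver such as `MyLinSolver`. The NumPy sketch below is our own illustration of one BDF1 step, not pressio4py's implementation. It assumes a hypothetical linear FOM $\mathbf{f}(\mathbf{x}) = \mathbf{A}\mathbf{x}$, a zero reference state, and a stopping rule on the norm of the correction in the spirit of `WhenCorrectionAbsoluteNormBelowTolerance`. The usage example uses a linear decoder so the result can be checked against `np.linalg.lstsq`.

```python
import numpy as np

def lspg_bdf1_step(decode, jac, A, dt, rom_prev, tol=1e-10, max_iters=10):
    """One BDF1 LSPG step: minimize ||r(y)||^2 over the ROM state y with Gauss-Newton.

    r(y) = g(y) - g(rom_prev) - dt * A g(y), for a hypothetical linear FOM f(x) = A x
    and a zero reference state.
    """
    g_prev = decode(rom_prev)
    y = rom_prev.astype(float)                # initial guess: previous ROM state (a copy)
    M = np.eye(A.shape[0]) - dt * A           # dr/dx for BDF1 with f(x) = A x
    for _ in range(max_iters):
        gy = decode(y)
        r = gy - g_prev - dt * (A @ gy)       # discrete BDF1 residual
        Jr = M @ jac(y)                       # residual jacobian = (I - dt A) J_g(y)
        # Gauss-Newton normal equations: (Jr^T Jr) delta = -Jr^T r
        delta = np.linalg.solve(Jr.T @ Jr, -Jr.T @ r)
        y += delta
        if np.linalg.norm(delta) < tol:       # stop on the correction norm
            break
    return y

# usage: random stable-ish A, linear decoder g(y) = Phi y with orthonormal Phi
rng = np.random.default_rng(1)
n, k, dt = 50, 4, 1e-2
A = -np.eye(n) + 0.1 * rng.standard_normal((n, n))
phi, _ = np.linalg.qr(rng.standard_normal((n, k)))
decode = lambda y: phi @ y
jac = lambda y: phi
rom_prev = rng.standard_normal(k)
y = lspg_bdf1_step(decode, jac, A, dt, rom_prev)
# with a linear decoder the step reduces to one linear least-squares problem
ref, *_ = np.linalg.lstsq((np.eye(n) - dt * A) @ phi, phi @ rom_prev, rcond=None)
print(np.allclose(y, ref))  # True
```

With a nonlinear decoder, such as the MLP, the same loop needs several iterations and the decoder Jacobian must be re-evaluated at every iterate, which is what `updateJacobian` is for.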

Results

If everything works fine, the following plot shows the result.

[Figure: result plot]