Generate Lorenz attractor data¶
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def lorenz_attractor(x, y, z, s=10, r=28, b=2.667):
    """Evaluate the Lorenz system's time derivatives at a point.

    s, r, b are the classic physical parameters (sigma, rho, beta);
    see https://en.wikipedia.org/wiki/Lorenz_system

    Returns the tuple (dx/dt, dy/dt, dz/dt).
    """
    x_dot = s * (y - x)
    y_dot = x * (r - z) - y
    z_dot = x * y - b * z
    return x_dot, y_dot, z_dot
# Integration settings
dt = 0.01          # time step between consecutive samples
num_steps = 18251  # number of forward-Euler steps to take

# Pre-allocate one extra slot so index num_steps holds the final state
x = np.ones(num_steps + 1)
y = np.ones(num_steps + 1)
z = np.ones(num_steps + 1)

# Starting point of the trajectory
x[0], y[0], z[0] = 0.0, 1.0, 1.05

# Forward-Euler integration of the Lorenz system
for step in range(num_steps):
    dx, dy, dz = lorenz_attractor(x[step], y[step], z[step])
    x[step + 1] = x[step] + dx * dt
    y[step + 1] = y[step] + dy * dt
    z[step + 1] = z[step] + dz * dt
# Collect the trajectory in a dataframe
df = pd.DataFrame({'X': x, 'Y': y, 'Z': z})

# Plot the attractor as a thin 3D curve.
# (Removed the unused sin/cos "three-dimensional line" demo arrays that were
# left over from a plotting tutorial and never referenced.)
ax = plt.axes(projection='3d')
ax.plot3D(df.X, df.Y, df.Z, color='red', linewidth=0.2)
df
| X | Y | Z | |
|---|---|---|---|
| 0 | 0.000000 | 1.000000 | 1.050000 |
| 1 | 0.100000 | 0.990000 | 1.021996 |
| 2 | 0.189000 | 1.007078 | 0.995730 |
| 3 | 0.270808 | 1.048045 | 0.971077 |
| 4 | 0.348532 | 1.110761 | 0.948017 |
| ... | ... | ... | ... |
| 18247 | -1.866625 | -3.365286 | 19.796920 |
| 18248 | -2.016491 | -3.484754 | 19.331754 |
| 18249 | -2.163317 | -3.624700 | 18.886446 |
| 18250 | -2.309455 | -3.785608 | 18.461158 |
| 18251 | -2.457071 | -3.968048 | 18.056226 |
18252 rows Ɨ 3 columns
To help the model capture the local dynamics of the trajectory, I include first-order derivatives as additional input features. Each time step is therefore represented by six values: $x$, $y$, $z$, $\dot{x}$, $\dot{y}$, and $\dot{z}$.
# Estimate first derivatives numerically (np.gradient: second-order central
# differences in the interior, one-sided at the ends) with spacing dt.
dX_dt = np.gradient(df['X'], dt)
dY_dt = np.gradient(df['Y'], dt)
dZ_dt = np.gradient(df['Z'], dt)

# Assemble coordinates and their derivatives into a single 6-feature dataframe
df = pd.DataFrame({
    'X': df['X'],
    'Y': df['Y'],
    'Z': df['Z'],
    'dX_dt': dX_dt,
    'dY_dt': dY_dt,
    'dZ_dt': dZ_dt,
})
df
| X | Y | Z | dX_dt | dY_dt | dZ_dt | |
|---|---|---|---|---|---|---|
| 0 | 0.000000 | 1.000000 | 1.050000 | 10.000000 | -1.000000 | -2.800350 |
| 1 | 0.100000 | 0.990000 | 1.021996 | 9.450000 | 0.353900 | -2.713507 |
| 2 | 0.189000 | 1.007078 | 0.995730 | 8.540390 | 2.902265 | -2.545969 |
| 3 | 0.270808 | 1.048045 | 0.971077 | 7.976577 | 5.184163 | -2.385659 |
| 4 | 0.348532 | 1.110761 | 0.948017 | 7.697336 | 7.294653 | -2.223634 |
| ... | ... | ... | ... | ... | ... | ... |
| 18247 | -1.866625 | -3.365286 | 19.796920 | -15.266930 | -10.943827 | -47.510113 |
| 18248 | -2.016491 | -3.484754 | 19.331754 | -14.834620 | -12.970734 | -45.523738 |
| 18249 | -2.163317 | -3.624700 | 18.886446 | -14.648231 | -15.042745 | -43.529795 |
| 18250 | -2.309455 | -3.785608 | 18.461158 | -14.687683 | -17.167364 | -41.510995 |
| 18251 | -2.457071 | -3.968048 | 18.056226 | -14.761531 | -18.243921 | -40.493215 |
18252 rows Ɨ 6 columns
Build DataLoaders¶
Three utilities are used to prepare the training data:
normalize_data scales each feature to a fixed range (either [0, 1] or [-1, 1]). Recurrent models (including LSTMs) are sensitive to feature scale because their gating mechanisms rely heavily on tanh/sigmoid activations; unscaled inputs can lead to saturated activations and unstable gradients. In practice, [-1, 1] tends to work slightly better here. create_intervals builds a sliding-window dataset. With seq_length=366 and predict_days=1, it returns:
- an input sequence of length 365 (seq_length - predict_days)
- a target consisting of the next time step (1 step)
StockDataset is a lightweight torch.utils.data.Dataset wrapper that converts the NumPy arrays into tensors so they can be consumed by a PyTorch DataLoader.
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
def normalize_data(df, feature_range=(-1, 1)):
    """Scale every feature column into *feature_range*.

    Parameters
    ----------
    df : array-like of shape (n_samples, n_features)
        Raw feature matrix (e.g. ``df.values``).
    feature_range : tuple (min, max), default (-1, 1)
        Target range. [-1, 1] suits tanh/sigmoid-gated recurrent nets;
        the default preserves the previous hard-coded behavior.

    Returns
    -------
    (scaled_data, scaler)
        The transformed array and the fitted MinMaxScaler, which is
        needed later to invert the scaling on predictions.
    """
    scaler = MinMaxScaler(feature_range=feature_range)
    scaled_data = scaler.fit_transform(df)
    return scaled_data, scaler
def create_intervals(data, seq_length=366, predict_days=1):
    """Build a sliding-window dataset from a multivariate time series.

    Each window of ``seq_length`` consecutive rows is split into an input
    of the first ``seq_length - predict_days`` rows and a target of the
    final ``predict_days`` rows. All feature columns are kept in both
    inputs and targets.

    Returns
    -------
    (sequences, targets) : numpy arrays of shape
        (n_windows, seq_length - predict_days, n_features) and
        (n_windows, predict_days, n_features).
    """
    sequences = []
    targets = []
    split = seq_length - predict_days  # hoisted loop invariant
    for i in range(len(data) - seq_length):
        window = data[i:i + seq_length]
        sequences.append(window[:split])
        targets.append(window[split:])
    return np.array(sequences), np.array(targets)
# Thin Dataset wrapper: converts numpy windows/targets to float32 tensors
# on demand so a DataLoader can batch them.
class StockDataset(Dataset):
    def __init__(self, sequences, targets):
        self.sequences = sequences
        self.targets = targets

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        seq = torch.tensor(self.sequences[index], dtype=torch.float32)
        tgt = torch.tensor(self.targets[index], dtype=torch.float32)
        return seq, tgt
# Scale all six features into [-1, 1]; keep the fitted scaler to invert later
scaled_data, scaler = normalize_data(df.values)
# Build sliding windows of seq_length=366 (365 input steps + 1 target step)
sequences, targets = create_intervals(scaled_data)
# Train-Test Split: first 80% of windows train, last 20% test (chronological)
train_size = int(len(sequences) * 0.8) #outputs a single int
test_size = len(sequences) - train_size  # NOTE(review): computed but unused below
# Training portion: each 365-step input window is paired with the next single step
train_sequences = sequences[:train_size]
train_targets = targets[:train_size]
# Test portion: same input/target pairing on the held-out tail of the series
test_sequences = sequences[train_size:]
test_targets = targets[train_size:]
# Wrap in Datasets and DataLoaders (only the training windows are shuffled)
train_dataset = StockDataset(train_sequences, train_targets)
test_dataset = StockDataset(test_sequences, test_targets)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
Each batch produced by the DataLoaders contains:
inputs with shape [128, 365, 6] = [batch, time, features]; targets with shape [128, 1, 6] = [batch, horizon, features]
# Sanity check: each batch should be [batch, time, features] / [batch, horizon, features]
# NOTE(review): the loop rebinds the module-level name `x` (the trajectory array)
# to the last batch tensor — confirm nothing downstream still expects the array.
for x,i in train_loader:
    print (x.shape, i.shape)
torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) 
torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) 
torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([128, 365, 6]) torch.Size([128, 1, 6]) torch.Size([100, 365, 6]) torch.Size([100, 1, 6])
Define the model¶
The model consists of stacked LSTM layers followed by a small feed-forward head:
LSTM backbone. In the forward pass, the hidden state and cell state are initialized and passed together with the input sequence. From the LSTM outputs, we use the final hidden state of the top LSTM layer (
h_n[-1]), which summarizes the entire input window. MLP prediction head. The final hidden state is mapped to the 6-dimensional output (next-step $x, y, z, \dot{x}, \dot{y}, \dot{z}$) via several fully connected layers. Empirically, a shallow head underfits, while adding many more layers provides limited benefit.
Dropout and activations. Dropout is used for regularization and to improve stability during training. The nonlinearity is SELU, which worked comparably to
tanh in this setting; ReLU tended to produce overly piecewise predictions.
# Inspect the shape of the last time step's (x, y, z) coordinates per batch.
# NOTE(review): `input` shadows the builtin and `targets` rebinds the module-level
# array from the split cell — harmless here but worth renaming eventually.
for input, targets in train_loader:
    print (input[:, -1, :3].shape)
torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) 
torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([128, 3]) torch.Size([100, 3])
"""import torch
import torch.nn as nn
class stocklstm(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, seq_length=366, predict_days=1):
super().__init__()
self.num_layers = num_layers
self.hidden_size = hidden_size
self.seq_length = seq_length
self.predict_days = predict_days * 6
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
num_layers=num_layers, batch_first=True, dropout=0.4) #lstm
self.fc_1 = nn.Linear(hidden_size, hidden_size*2) #fully connected 1
self.fc_2 = nn.Linear(hidden_size*2, hidden_size*4) #fully connected 1
self.fc_3 = nn.Linear(hidden_size*4, hidden_size*2) #fully connected 1
self.fc_4 = nn.Linear(hidden_size*2, hidden_size) #fully connected 1
# Fully connected layers to map hidden state to output predictions
self.fc_out = nn.Linear(hidden_size, self.predict_days)
self.dropout= nn.Dropout(p=0.5)
self.selu = nn.SELU()
def forward(self,x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
output, (h_n, c_n) = self.lstm(x, (h0, c0)) # LSTM outputs all hidden states
out = h_n[-1] # Take the output of the last time step, the last h_n of last layer, size [128,128]
out = self.selu(self.fc_1(out)) # Activation after first FC
out = self.dropout(out)
out = self.selu(self.fc_2(out)) # Activation after first FC
out = self.dropout(out)
out = self.selu(self.fc_3(out)) # Activation after first FC
out = self.dropout(out)
out = self.selu(self.fc_4(out)) # Activation after first FC
out = self.dropout(out)
out = self.fc_out(out) # Final output layer without activation (for regression)
#print (out.shape)
#out = out.reshape(-1,self.predict_days / 3,3)
return out
"""
'import torch\nimport torch.nn as nn\n\nclass stocklstm(nn.Module):\n def __init__(self, input_size, hidden_size, num_layers, seq_length=366, predict_days=1):\n super().__init__()\n self.num_layers = num_layers\n self.hidden_size = hidden_size\n self.seq_length = seq_length\n self.predict_days = predict_days * 6\n\n self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,\n num_layers=num_layers, batch_first=True, dropout=0.4) #lstm\n\n self.fc_1 = nn.Linear(hidden_size, hidden_size*2) #fully connected 1\n self.fc_2 = nn.Linear(hidden_size*2, hidden_size*4) #fully connected 1\n self.fc_3 = nn.Linear(hidden_size*4, hidden_size*2) #fully connected 1\n self.fc_4 = nn.Linear(hidden_size*2, hidden_size) #fully connected 1\n # Fully connected layers to map hidden state to output predictions\n self.fc_out = nn.Linear(hidden_size, self.predict_days)\n\n\n self.dropout= nn.Dropout(p=0.5)\n self.selu = nn.SELU()\n\n def forward(self,x):\n h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)\n c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)\n\n output, (h_n, c_n) = self.lstm(x, (h0, c0)) # LSTM outputs all hidden states\n out = h_n[-1] # Take the output of the last time step, the last h_n of last layer, size [128,128]\n\n\n out = self.selu(self.fc_1(out)) # Activation after first FC\n out = self.dropout(out)\n\n out = self.selu(self.fc_2(out)) # Activation after first FC\n out = self.dropout(out)\n\n out = self.selu(self.fc_3(out)) # Activation after first FC\n out = self.dropout(out)\n \n out = self.selu(self.fc_4(out)) # Activation after first FC\n out = self.dropout(out)\n\n out = self.fc_out(out) # Final output layer without activation (for regression)\n #print (out.shape)\n #out = out.reshape(-1,self.predict_days / 3,3)\n return out\n\n\n'
import torch
import torch.nn as nn

class lorentzlstm(nn.Module):
    """Stacked LSTM plus an MLP head predicting the next system state.

    The LSTM consumes a window of shape (batch, seq, input_size); the final
    hidden state of the top layer (h_n[-1]) summarizes the window and is
    mapped through an expand-then-contract MLP to `predict_days` future
    steps, flattened to (batch, predict_days * input_size).
    """

    def __init__(self, input_size, hidden_size, num_layers, seq_length=366, predict_days=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.seq_length = seq_length
        # Output width: one full feature vector per predicted step.
        # Generalized from the previous hard-coded `predict_days * 6` so the
        # head stays consistent with any input_size (identical for the 6-feature
        # data used in this notebook).
        self.predict_days = predict_days * input_size

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True, dropout=0.4)

        # Expand-then-contract MLP head mapping h_n[-1] to the prediction
        self.fc_1 = nn.Linear(hidden_size, hidden_size * 2)
        self.fc_2 = nn.Linear(hidden_size * 2, hidden_size * 4)
        self.fc_3 = nn.Linear(hidden_size * 4, hidden_size * 2)
        self.fc_4 = nn.Linear(hidden_size * 2, hidden_size)
        # Final projection to the flattened prediction
        self.fc_out = nn.Linear(hidden_size, self.predict_days)

        self.dropout = nn.Dropout(p=0.5)
        self.selu = nn.SELU()

    def forward(self, x):
        """x: (batch, seq, input_size) -> (batch, predict_days * input_size)."""
        # Explicit zero initial states, allocated directly on the input's device
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)

        output, (h_n, c_n) = self.lstm(x, (h0, c0))
        out = h_n[-1]  # final hidden state of the top layer: (batch, hidden_size)

        # Each hidden layer: linear -> dropout -> SELU (order kept from original)
        out = self.selu(self.dropout(self.fc_1(out)))
        out = self.selu(self.dropout(self.fc_2(out)))
        out = self.selu(self.dropout(self.fc_3(out)))
        out = self.selu(self.dropout(self.fc_4(out)))

        # Linear output head: no activation (regression target)
        return self.fc_out(out)
Training configuration notes:
Model size. The model is instantiated with
input_size=6, hidden_size=128, and num_layers=9. While many references suggest that 1–3 LSTM layers are often sufficient, increasing depth improved performance for this specific task (likely because the dynamics are highly nonlinear). Further improvements may also come from alternatives such as larger hidden size, layer normalization, or architectural changes rather than adding more layers. Loss function. Both
MSELoss and HuberLoss were tested. Since this dataset is deterministic and essentially noise-free, MSELoss is a natural choice; in practice the two behaved similarly. Optimizer. Several optimizers were evaluated (AdamW, Adam, Adagrad, SGD, RMSprop). RMSprop performed best in this notebook with
lr=1e-4 and momentum=0.9.
device = "cuda"
lstm = lorentzlstm(6,128,9).to("cuda")
criterion = nn.MSELoss()
optimizer = torch.optim.RMSprop(lstm.parameters(), lr=0.0001, momentum=0.9)
lstm
lorenzlstm( (lstm): LSTM(6, 128, num_layers=9, batch_first=True, dropout=0.4) (fc_1): Linear(in_features=128, out_features=256, bias=True) (fc_2): Linear(in_features=256, out_features=512, bias=True) (fc_3): Linear(in_features=512, out_features=256, bias=True) (fc_4): Linear(in_features=256, out_features=128, bias=True) (fc_out): Linear(in_features=128, out_features=6, bias=True) (dropout): Dropout(p=0.5, inplace=False) (selu): SELU() )
# Smoke test: push one training batch through the untrained model and inspect
# the raw (unscaled) predictions.
# NOTE(review): `input` shadows the builtin of the same name.
for input, targets in train_loader:
    input= input.to(device)
    print(lstm(input))
    break
tensor([[ 4.5706e-01, 5.9219e-01, -3.2413e-01, 7.2902e-02, -3.9372e-01,
1.0854e+01],
[-2.3636e-01, -2.9977e-01, -7.2791e-01, -1.3338e+01, -1.5106e+01,
1.2457e+01],
[-6.9575e-01, -3.4633e-01, 3.1442e-01, -8.4478e+00, -2.6035e+00,
-1.3429e-02],
[-2.6344e-01, -2.0738e-01, -1.5209e-01, -3.0899e+00, -1.3704e+00,
5.3749e+00],
[ 5.9867e-02, 1.0903e-01, -5.4822e-01, -4.9003e+00, -3.1270e+00,
2.1530e+00],
[-4.2312e-01, -2.9975e-01, -5.9720e-04, 4.4988e+00, 6.7966e+00,
-9.9345e-01],
[ 1.1748e-01, 9.9251e-02, -2.4118e-01, -3.7457e+00, 1.6559e-01,
-2.4794e+00],
[-5.6432e-01, -5.6410e-01, -1.0804e-03, -3.1674e+00, -6.1743e+00,
4.4929e+00],
[ 1.3977e-01, 1.7970e-01, -7.3551e-01, -5.9584e+00, -9.7566e+00,
-8.3475e+00],
[ 2.7093e-01, 2.0480e-01, -1.2517e-01, -4.6318e+00, -6.9207e+00,
3.9184e+00],
[ 5.8506e-02, 2.4120e-02, -2.1770e-01, -7.7911e+00, -8.4149e+00,
7.6862e+00],
[-5.5449e-01, -3.5391e-01, 5.2127e-02, -1.6009e+01, -8.7792e+00,
4.6942e+00],
[-1.3223e-01, 1.1460e-02, -5.3280e-03, 4.7991e+00, -2.1139e+00,
-4.8492e+00],
[ 4.4891e-01, 3.8311e-01, 2.4625e-02, -2.5773e+00, -7.6111e+00,
1.2039e+01],
[ 4.3122e-02, 1.3942e-01, -5.2556e-01, -1.7972e+01, -1.6105e+01,
7.3489e+00],
[-1.8516e-01, -2.4053e-01, -4.9439e-01, -9.6922e+00, -1.3543e+01,
-1.3873e+00],
[ 2.1447e-01, 2.0735e-01, -1.1855e-01, -1.0805e+01, -6.9515e+00,
3.8156e+00],
[-4.0323e-01, -3.0041e-01, 1.3249e-02, 5.4121e+00, 6.5406e+00,
2.2687e+00],
[ 2.0345e-01, -3.1579e-02, 1.4339e-01, -8.9741e+00, 3.6023e+00,
-6.9705e+00],
[-8.8486e-01, -8.1751e-01, 4.0597e-02, -1.3695e+01, -3.0540e+00,
-1.6027e+00],
[-3.3790e-01, 8.0693e-02, 2.6635e-01, -1.2759e+01, -6.8783e+00,
8.0309e+00],
[-2.7190e-02, 4.2646e-02, -4.5539e-01, -4.1125e+00, 4.2977e+00,
-3.6577e+00],
[ 5.3388e-01, 2.4911e-01, 1.8353e-01, 1.7844e+01, 5.7405e+00,
1.0346e+01],
[-2.5825e-01, -1.7355e-01, 6.3927e-02, 1.3895e+01, 3.5734e+00,
-1.5612e+00],
[-4.8680e-01, -3.4559e-01, 1.6527e-02, -1.2544e+00, -6.8753e-01,
-4.2065e+00],
[ 2.1386e-01, -9.6538e-02, 2.1728e-01, -4.8549e+00, -4.2956e+00,
5.6135e+00],
[ 2.0297e-01, 1.0255e-01, -4.5017e-02, -1.6142e+01, -1.7096e+01,
1.2138e+00],
[ 2.3929e-01, 3.5316e-01, -3.7058e-01, -6.7016e+00, 8.8146e-02,
1.7038e-01],
[ 2.1552e-01, 3.0268e-01, -5.6545e-01, -3.7417e+00, -2.6029e+00,
-2.9043e+00],
[ 6.6325e-01, 3.2544e-01, 1.7666e-01, 2.1654e+01, 1.6419e+01,
-4.6540e+00],
[ 3.7139e-01, -1.8916e-02, 2.8381e-01, -7.8105e+00, 1.7753e+01,
-2.1573e+01],
[ 2.3600e-01, 2.3274e-01, -1.6113e-01, -1.1591e+01, -8.1705e+00,
-8.8523e-01],
[ 5.7741e-01, 6.3892e-01, -2.1348e-01, 1.8281e+01, 1.7859e+01,
1.1580e+01],
[ 4.7881e-01, 1.4762e-01, 2.0431e-01, 8.5816e+00, 3.3745e+00,
1.8798e+00],
[-7.3636e-01, -6.8643e-01, -7.0334e-02, -2.2276e+01, -1.3740e+01,
9.7507e+00],
[-2.4289e-01, -2.0481e-01, -1.8683e-01, -8.0444e-01, 7.8556e-01,
5.4584e+00],
[-3.6176e-01, -2.1137e-01, 2.2190e-02, 1.6457e+00, 3.7715e+00,
2.3722e+00],
[-7.2457e-01, -4.1920e-01, 3.1168e-01, -1.1235e+01, -5.4377e+00,
3.7685e+00],
[-5.3189e-01, -6.1972e-01, -3.6150e-01, 9.4366e-02, 2.5083e+00,
-5.6360e+00],
[-4.3836e-02, -8.2365e-02, -1.9024e-01, -7.3975e+00, -2.1166e+00,
-2.4645e+00],
[-6.1709e-01, -4.9126e-01, 3.9841e-01, 2.0277e+00, 2.9145e+00,
2.4451e+01],
[-4.3709e-02, 4.5394e-02, -4.1511e-02, -1.3456e+00, -3.8715e+00,
3.4252e+00],
[ 2.1895e-01, -2.6684e-02, 1.6973e-01, -2.3671e+01, -8.3290e+00,
-1.5996e+01],
[ 4.6162e-01, 3.2359e-01, -3.7117e-03, 1.8961e+00, 6.2198e-01,
-4.7307e+00],
[-1.7750e-01, -2.1658e-01, -4.1395e-01, -2.2070e+00, -3.9322e+00,
1.2013e+00],
[-6.5496e-01, -4.9381e-01, 3.1155e-01, 5.3702e+00, 1.3069e+01,
1.3449e+01],
[ 8.5956e-01, 6.9507e-01, 1.9602e-01, 1.7075e+01, 8.5904e+00,
4.1940e-01],
[ 2.0063e-01, 1.9193e-01, -2.4913e-01, 2.2939e+00, 5.1831e-01,
1.5525e-01],
[-4.5837e-01, -3.6277e-01, -5.0755e-03, 3.1728e+00, -2.5903e+00,
-9.4840e+00],
[-8.7084e-01, -3.8919e-01, 6.6428e-01, -1.3449e+01, 1.2413e+01,
3.1175e+01],
[-3.4278e-01, -3.4414e-01, -1.8231e-01, 2.5199e+00, 3.6291e+00,
6.0332e+00],
[-3.5112e-01, -3.2888e-01, -5.5591e-02, 1.1614e+01, 7.7847e+00,
-7.9130e-01],
[-3.1043e-01, -2.7639e-01, -5.4530e-02, -7.0320e+00, 3.8980e+00,
1.5576e-01],
[-6.3635e-01, -3.2490e-01, 1.0898e-01, -4.0231e+00, 3.6474e+00,
-1.4063e+01],
[-3.3592e-01, -3.2114e-01, -6.6361e-02, 1.1882e+01, 3.9980e+00,
-5.8212e+00],
[-8.8912e-02, -1.1174e-01, -7.1589e-01, -5.0879e+00, -6.2174e+00,
-1.0604e+01],
[-3.8817e-01, -3.3325e-01, -3.3812e-02, 9.0600e+00, 9.6577e+00,
1.3311e+00],
[-4.5328e-01, -1.6877e-01, 2.7180e-01, 1.1551e+01, 9.1798e-01,
-9.8503e+00],
[ 4.9234e-01, 3.2229e-01, 1.0149e-01, -8.7148e+00, -2.4012e+00,
-1.6358e+01],
[-3.4939e-01, -3.6913e-01, -1.9450e-01, -3.5491e+00, -2.7204e+00,
7.0649e+00],
[ 3.5069e-01, 3.4882e-01, -7.3911e-02, -9.0740e+00, -3.4802e+00,
-1.8687e+00],
[-6.9478e-02, -2.1910e-01, 1.3125e-02, -1.1725e+01, -3.9037e+00,
-1.9804e+00],
[ 5.5236e-01, 4.6620e-01, 3.5994e-02, 1.4394e+01, 8.6059e+00,
1.5431e+01],
[-5.9842e-01, -5.5735e-01, -1.1789e-01, -1.5240e+01, -1.4716e+01,
-1.7677e+00],
[ 2.5562e-01, 2.3248e-01, -1.6493e-01, -2.8526e+00, -5.0240e+00,
1.0697e+01],
[-2.8930e-01, -1.7796e-01, -6.2743e-02, 5.6372e+00, -3.0559e-01,
-9.9510e+00],
[-3.8307e-01, -5.8047e-01, -6.0355e-01, 1.0151e+01, 7.2144e+00,
-9.3873e+00],
[-3.5098e-01, -3.1085e-01, -1.2332e-01, 5.6004e+00, 7.1535e-01,
-8.0995e+00],
[ 5.9105e-01, 3.5145e-01, 2.7617e-01, 4.9226e+00, 4.5452e-01,
8.0378e+00],
[-5.0977e-01, -3.4236e-01, 1.2570e-01, 2.1400e+00, 2.2076e-01,
-2.8986e+00],
[-5.1198e-01, -3.7433e-01, 1.3594e-01, 7.3759e+00, 3.5824e+00,
-4.4092e+00],
[ 2.3821e-01, 1.0805e-01, -5.8209e-04, -2.5514e+00, -8.4299e-02,
1.6109e+00],
[ 3.6005e-01, 3.0778e-01, 3.4564e-02, -1.5615e+01, -7.9025e+00,
-6.4659e+00],
[-5.9287e-01, -4.5682e-01, 1.7896e-01, -2.7328e+00, -5.1657e+00,
2.9717e+00],
[-2.9207e-02, -1.9712e-02, -3.3435e-01, -6.1084e+00, -5.8978e+00,
7.2281e+00],
[ 2.3829e-02, -1.6774e-01, -4.4813e-02, 4.2294e-01, -2.3801e-01,
-2.8119e+00],
[ 2.0444e-01, 6.6389e-02, -6.1231e-03, -9.2312e+00, -4.5491e+00,
-4.1395e+00],
[ 4.7536e-01, 4.6242e-01, 4.3502e-02, -3.2223e+00, -7.6268e+00,
1.8817e+01],
[ 2.5064e-01, 1.7238e-01, -3.3934e-02, -2.3698e+00, -3.3705e-01,
5.9549e+00],
[-3.3157e-01, -3.2817e-01, -1.7639e-01, -6.2711e+00, -4.4571e+00,
1.3929e+01],
[-1.6412e-01, -2.4306e-01, -3.9121e-01, 1.2604e+00, -4.1189e+00,
2.9443e+00],
[ 2.4618e-01, 9.8893e-02, -4.3148e-03, -1.0886e+01, -7.3593e+00,
-6.4842e+00],
[ 2.3558e-02, 6.4761e-02, -8.0237e-02, -1.0254e+00, -3.3013e+00,
6.5812e+00],
[-1.8092e-01, -1.9998e-01, -2.9721e-01, 5.3206e-01, 2.3636e+00,
9.4154e+00],
[-1.3522e-01, -1.9182e-01, -2.0662e-01, -6.3542e+00, -5.2282e+00,
-3.3052e-01],
[ 5.2101e-01, 5.2484e-01, -2.3666e-01, -2.4632e+00, -8.4696e+00,
-7.3449e+00],
[-5.4964e-01, -4.7247e-01, 1.6707e-02, 9.1840e-01, -1.4307e+00,
-5.4837e+00],
[-1.1385e+00, -1.0748e+00, 3.6275e-02, -4.0346e+01, -2.6390e+01,
4.7576e+00],
[-5.4056e-01, -3.8437e-01, 4.4346e-02, -1.0760e-01, 9.5299e-01,
-6.9680e+00],
[-1.9074e-01, -1.7497e-02, 9.4114e-02, -5.0622e+00, -6.2583e+00,
8.5333e+00],
[-1.1209e-01, -1.0318e-01, -4.6016e-01, -4.0158e-01, 2.4355e+00,
8.1435e-01],
[ 4.4195e-01, 4.9983e-02, 2.2555e-01, -1.0621e+00, -8.6646e+00,
-2.9927e+00],
[-7.3296e-01, -3.7874e-01, 3.5006e-01, 8.7810e-01, -7.9118e+00,
-1.9583e+01],
[-3.6962e-01, -2.4493e-01, -4.8879e-02, 1.4432e+00, -1.1767e+00,
-7.7979e+00],
[-5.4574e-01, -6.6280e-01, -1.5804e-01, 3.4485e+00, 2.2084e+00,
7.7131e+00],
[-1.4294e-01, -2.8398e-01, 5.6203e-03, -5.7029e+00, -1.6421e+00,
2.4155e-01],
[ 3.8756e-01, 2.9761e-01, 2.7136e-02, -7.9143e+00, -2.2213e+00,
-5.9150e+00],
[-4.3166e-01, -3.3144e-01, 1.4778e-01, 6.9083e-01, -4.5419e+00,
8.8889e+00],
[ 7.4095e-02, 9.6721e-02, -9.0164e-01, 2.1012e-01, -1.0984e+00,
-1.4211e+01],
[-5.1086e-02, 2.1163e-01, 1.6787e-01, -6.8572e+00, -1.4064e+00,
1.0911e+01],
[ 3.0154e-01, 1.6241e-01, 1.5174e-02, 1.6737e+00, 5.2614e+00,
-2.8253e-01],
[ 8.2974e-02, -2.7078e-01, 1.6944e-01, -1.6476e+01, 2.2930e+00,
-1.6186e+01],
[-3.5574e-01, -1.9818e-01, 2.0885e-01, 1.8032e+01, 4.1685e+00,
-5.8657e+00],
[ 3.2363e-01, 4.0125e-01, -2.1830e-01, -6.6348e-01, 4.9859e+00,
5.5418e+00],
[ 4.6960e-01, 2.2421e-01, 1.6993e-01, 1.5223e+00, 5.2906e+00,
-5.5762e+00],
[-4.3706e-01, -2.8690e-01, -4.9908e-03, -5.5123e+00, -5.8009e+00,
-3.5803e+00],
[-4.1920e-02, -3.4013e-02, -7.4290e-01, -3.9132e+00, -3.2754e+00,
-2.6803e+00],
[-5.9298e-01, -3.5351e-01, 2.8181e-01, 2.2947e+01, 3.4818e+01,
-4.9614e+00],
[-3.1012e-01, -3.1927e-01, -4.8845e-01, -8.9364e+00, -4.7594e+00,
-6.0642e-02],
[-3.6832e-01, 1.9798e-01, 3.2501e-01, 8.9770e+00, 5.2406e+00,
-1.2633e+01],
[ 2.0473e-01, -2.9946e-02, 1.9027e-01, -2.1353e-01, 7.0200e+00,
5.5199e+00],
[ 3.7670e-01, 3.8886e-01, 8.7231e-03, -1.2324e+01, -7.5578e+00,
4.9178e+00],
[-2.9358e-01, -3.0313e-01, -5.0451e-01, -1.2959e+01, -9.9090e+00,
-3.2896e+00],
[ 4.8542e-01, 5.9339e-01, -2.8306e-01, -1.4257e+00, 3.4923e-01,
-1.8165e+00],
[-5.1524e-02, -1.9421e-02, -5.1417e-01, -1.1775e+00, 2.4988e+00,
-6.5986e+00],
[-6.4935e-02, 1.1519e-01, 1.1128e-01, 1.0456e+00, -7.8174e+00,
1.1358e+00],
[-4.6985e-01, -3.3564e-01, 3.8050e-02, 1.1485e+01, 6.6417e+00,
-1.4896e+01],
[ 1.5784e-02, 1.8734e-03, -4.6540e-01, -6.4880e+00, -9.3138e+00,
-7.9960e-01],
[-3.7958e-01, -2.8853e-01, -2.0593e-02, 3.2070e+00, 1.4624e+00,
-8.2976e-01],
[-3.4538e-01, -2.5141e-01, -6.3491e-02, 6.5816e+00, 4.1525e+00,
-6.4046e+00],
[-2.9733e-01, -3.7599e-01, -3.1068e-01, 3.1308e+00, -1.9652e+00,
-3.8378e-01],
[ 4.6185e-01, 1.1708e-01, 3.2040e-01, -6.4888e+00, 3.6486e+00,
-9.4234e+00],
[-1.9995e-01, -2.2326e-01, -2.0695e-01, -7.7680e-01, -3.4087e+00,
1.0273e+01],
[-4.2778e-01, -3.1557e-01, 4.4745e-02, 1.8006e+00, 2.1934e+00,
3.8276e+00],
[-4.0893e-01, -3.1460e-01, -7.9684e-03, 3.5355e+00, 3.9236e+00,
1.0755e+00],
[ 2.0062e-01, 2.0944e-01, -1.6319e-01, -1.0380e+01, -1.4709e+00,
-5.5313e+00],
[ 2.1377e-01, -4.9791e-02, 1.5049e-01, -1.4535e+01, -2.4302e+00,
-1.1280e+01],
[-2.2736e-01, -2.2473e-01, -9.3776e-02, 1.2261e+01, 8.2166e+00,
5.1882e+00]], device='cuda:0', grad_fn=<CatBackward0>)
# Render the LSTM's autograd graph to an SVG using torchviz.
# NOTE(review): `lstm` and `device` are defined in earlier cells; the dummy
# input is (1, 365, 6), which presumably matches the (batch, seq_len,
# features) layout used during training — confirm against the data-prep cell.
from torchviz import make_dot
x = torch.randn(1, 365 , 6).to(device) # Replace with appropriate input size
output = lstm(x)
# Visualize the model
dot = make_dot(output, params=dict(lstm.named_parameters()))
dot.render("lorentz_attractor_model_architecture", format="svg")
'lorentz_attractor_model_architecture.svg'
Train the model¶
# Training loop: standard train/validation cycle with gradient value clipping.
# Relies on notebook globals: lstm, criterion, optimizer, train_loader,
# test_loader, device.
clip_value = 0.8   # per-element gradient clipping threshold
num_epochs = 20

for epoch in range(num_epochs):
    # ---- Training phase ----
    lstm.train()
    train_loss = 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        outputs = lstm(inputs)
        # NOTE(review): targets appear to be (batch, 1, 6) — see the
        # `test_targets[..., 0, 0]` indexing later in the notebook.
        # squeeze(1) removes only the singleton time dim; the original bare
        # squeeze() would ALSO collapse the batch dim whenever the last
        # batch happens to have size 1, silently broadcasting in the loss.
        loss = criterion(outputs, targets.squeeze(1))

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_value_(lstm.parameters(), clip_value)
        optimizer.step()

        train_loss += loss.item()

    # Average train loss for this epoch
    train_loss /= len(train_loader)

    # ---- Validation phase (on the test set) ----
    lstm.eval()
    test_loss = 0
    with torch.inference_mode():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = lstm(inputs)
            loss = criterion(outputs, targets.squeeze(1))
            test_loss += loss.item()

    # Average test loss for this epoch
    test_loss /= len(test_loader)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
Epoch [1/20], Train Loss: 33.2310, Test Loss: 5.9038 Epoch [2/20], Train Loss: 30.7409, Test Loss: 1.3986 Epoch [3/20], Train Loss: 29.4363, Test Loss: 2.3024 Epoch [4/20], Train Loss: 28.6587, Test Loss: 0.8585 Epoch [5/20], Train Loss: 28.5013, Test Loss: 0.9563 Epoch [6/20], Train Loss: 28.9173, Test Loss: 1.8390 Epoch [7/20], Train Loss: 28.6386, Test Loss: 1.7320 Epoch [8/20], Train Loss: 27.7666, Test Loss: 3.1902 Epoch [9/20], Train Loss: 28.1951, Test Loss: 3.6165 Epoch [10/20], Train Loss: 28.2287, Test Loss: 1.5740 Epoch [11/20], Train Loss: 27.6656, Test Loss: 1.2109 Epoch [12/20], Train Loss: 27.6725, Test Loss: 1.8751 Epoch [13/20], Train Loss: 27.4233, Test Loss: 1.7084 Epoch [14/20], Train Loss: 26.8021, Test Loss: 1.8940 Epoch [15/20], Train Loss: 26.6064, Test Loss: 2.1320 Epoch [16/20], Train Loss: 26.6971, Test Loss: 3.1355 Epoch [17/20], Train Loss: 26.4950, Test Loss: 2.1259 Epoch [18/20], Train Loss: 26.3924, Test Loss: 1.9396 Epoch [19/20], Train Loss: 26.5422, Test Loss: 2.8664
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Cell In[58], line 22 19 nn.utils.clip_grad_value_(lstm.parameters(), clip_value) 20 optimizer.step() ---> 22 train_loss += loss.item() 24 # Average train loss for this epoch 25 train_loss /= len(train_loader) KeyboardInterrupt:
"""# Training loop
clip_value = 0.5
num_epochs = 20
for epoch in range(num_epochs):
lstm.train()
train_loss = 0
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
# Forward pass
output = lstm(inputs)
dx_dt = (output[2:, 0] - output[:-2, 0]) / (2 * dt)
dy_dt = (output[2:, 1] - output[:-2, 1]) / (2 * dt)
dz_dt = (output[2:, 2] - output[:-2, 2]) / (2 * dt)
true_dx_dt = targets.squeeze()[1:-1,3]
true_dy_dt = targets.squeeze()[1:-1,4]
true_dz_dt = targets.squeeze()[1:-1,5]
loss_positions = criterion(output, targets.squeeze()[:,:3]) # Loss between predicted and true x, y, z
loss_dx = criterion(dx_dt, true_dx_dt)/3
loss_dy = criterion(dy_dt, true_dy_dt)/3
loss_dz = criterion(dz_dt, true_dz_dt)/3
# Total loss: combining position and derivative losses
loss = loss_positions + loss_dx + loss_dy + loss_dz
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_value_(lstm.parameters(), clip_value)
optimizer.step()
train_loss += loss.item()
# Average train loss for this epoch
train_loss /= len(train_loader)
# Validation step (on test set)
lstm.eval()
test_loss = 0
with torch.inference_mode():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
output =lstm(inputs)
loss = criterion(output, targets.squeeze()[:,:3])
test_loss += loss.item()
# Average test loss for this epoch
test_loss /= len(test_loader)
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
"""
Visualize results¶
The final section visualizes an autoregressive rollout on a selected test sequence.
`plot_index` selects which test sequence to visualize. `predict_amount` controls how many future time steps are generated.
# Choose which held-out sequence to visualize and how many future steps to
# generate autoregressively.
plot_index=200
predict_amount=365
# Seed window for the rollout: shape (1, 365, 6), i.e. (batch, time,
# [x, y, z, dx/dt, dy/dt, dz/dt]) — shape confirmed by the cell output below.
A = torch.tensor(test_sequences[plot_index], dtype=torch.float32).unsqueeze(0).to(device)
A.shape
torch.Size([1, 365, 6])
Autoregressive inference is performed by repeatedly predicting the next step, appending it to the current sequence, and predicting again. As the horizon increases, small errors compound, so long rollouts are expected to become progressively less accurate.
# Autoregressive rollout: B starts as the seed window A and is extended by
# one predicted step per iteration. Small errors compound with the horizon,
# so long rollouts drift from the true trajectory.
B=A
lstm.eval()
with torch.inference_mode():
    for i in range(predict_amount):
        # lstm(B) predicts the next 6-feature step; unsqueeze(0) restores the
        # batch dim so the step can be concatenated along the time axis (dim 1).
        C = lstm(B).unsqueeze(0)
        B = torch.concat ((B, C), 1)
print (B.shape)
# Move both tensors off the GPU as NumPy arrays for matplotlib.
A = A.cpu().numpy()
B = B.cpu().numpy()
torch.Size([1, 730, 6])
# Sanity checks on array shapes before plotting (the bare tuples below are
# the recorded cell outputs).
test_sequences.shape
(3578, 365, 6)
test_targets[plot_index:plot_index + predict_amount, 0, 0].shape
(365,)
# Red:   the seed input window (test_sequences[plot_index]).
# Blue:  the model's rollout — the last `predict_amount` steps of B
#        (seed + autoregressive predictions).
# Green: the ground-truth continuation — `predict_amount` target steps
#        starting at plot_index.
ax = plt.axes(projection='3d')
#ax.view_init(0,90) #you can use this to move the view
ax.plot3D(test_sequences[plot_index, :, 0], test_sequences[plot_index, :, 1], test_sequences[plot_index, :, 2], 'red')
ax.plot3D(B[0, -predict_amount:, 0], B[0, -predict_amount:, 1], B[0, -predict_amount:, 2], 'blue')
ax.plot3D(test_targets[plot_index:plot_index + predict_amount, 0, 0], test_targets[plot_index:plot_index + predict_amount, 0, 1], test_targets[plot_index:plot_index + predict_amount, 0, 2], 'green')
[<mpl_toolkits.mplot3d.art3d.Line3D at 0x17d828dab10>]
#torch.save(lstm.state_dict(), "C:/Users/repea/Downloads/lorenz.pth")
#lstm.load_state_dict(torch.load("C:/Users/repea/Downloads/lorenz.pth"))
<All keys matched successfully>