Generate Lorenz attractor data¶

In [13]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def lorenz_attractor(x, y, z, s=10, r=28, b=2.667):
    """Return the Lorenz-system time derivatives (dx/dt, dy/dt, dz/dt).

    s, r, b are the standard physical parameters (sigma, rho, beta);
    see https://en.wikipedia.org/wiki/Lorenz_system
    """
    return (
        s * (y - x),
        x * (r - z) - y,
        x * y - b * z,
    )

# Integration settings
dt = 0.01  # step size of the explicit Euler scheme
num_steps = 18251  # number of integration steps

# State arrays; slot 0 holds the initial condition, slots 1..num_steps
# are filled by the integration loop below.
x = np.ones(num_steps + 1)
y = np.ones(num_steps + 1)
z = np.ones(num_steps + 1)

# Initial condition of the trajectory
x[0], y[0], z[0] = 0.0, 1.0, 1.05

# Forward-Euler integration of the Lorenz system
for step in range(num_steps):
    dx, dy, dz = lorenz_attractor(x[step], y[step], z[step])
    x[step + 1] = x[step] + dx * dt
    y[step + 1] = y[step] + dy * dt
    z[step + 1] = z[step] + dz * dt

# Collect the trajectory into a dataframe
df = pd.DataFrame({'X': x, 'Y': y, 'Z': z})


# 3D plot of the generated Lorenz trajectory
ax = plt.axes(projection='3d')

# Data for a three-dimensional line
# NOTE(review): zline/xline/yline below are computed but never plotted —
# they look like leftovers from a plotting demo and can be removed.
zline = np.linspace(0, 15, 1000)
xline = np.sin(zline)
yline = np.cos(zline)
ax.plot3D(df.X, df.Y, df.Z, color='red', linewidth=0.2 )
df
Out[13]:
X Y Z
0 0.000000 1.000000 1.050000
1 0.100000 0.990000 1.021996
2 0.189000 1.007078 0.995730
3 0.270808 1.048045 0.971077
4 0.348532 1.110761 0.948017
... ... ... ...
18247 -1.866625 -3.365286 19.796920
18248 -2.016491 -3.484754 19.331754
18249 -2.163317 -3.624700 18.886446
18250 -2.309455 -3.785608 18.461158
18251 -2.457071 -3.968048 18.056226

18252 rows × 3 columns

No description has been provided for this image

To help the model capture the local dynamics of the trajectory, I include first-order derivatives as additional input features. Each time step is therefore represented by six values: $x$, $y$, $z$, $\dot{x}$, $\dot{y}$, and $\dot{z}$.

In [14]:
# Numerical first derivatives of each coordinate with respect to time
dX_dt = np.gradient(df.X, dt)
dY_dt = np.gradient(df.Y, dt)
dZ_dt = np.gradient(df.Z, dt)

# Augmented dataframe: six features per time step (state + derivatives)
df = pd.DataFrame({
    'X': df.X,
    'Y': df.Y,
    'Z': df.Z,
    'dX_dt': dX_dt,
    'dY_dt': dY_dt,
    'dZ_dt': dZ_dt,
})
df
Out[14]:
X Y Z dX_dt dY_dt dZ_dt
0 0.000000 1.000000 1.050000 10.000000 -1.000000 -2.800350
1 0.100000 0.990000 1.021996 9.450000 0.353900 -2.713507
2 0.189000 1.007078 0.995730 8.540390 2.902265 -2.545969
3 0.270808 1.048045 0.971077 7.976577 5.184163 -2.385659
4 0.348532 1.110761 0.948017 7.697336 7.294653 -2.223634
... ... ... ... ... ... ...
18247 -1.866625 -3.365286 19.796920 -15.266930 -10.943827 -47.510113
18248 -2.016491 -3.484754 19.331754 -14.834620 -12.970734 -45.523738
18249 -2.163317 -3.624700 18.886446 -14.648231 -15.042745 -43.529795
18250 -2.309455 -3.785608 18.461158 -14.687683 -17.167364 -41.510995
18251 -2.457071 -3.968048 18.056226 -14.761531 -18.243921 -40.493215

18252 rows × 6 columns

Build DataLoaders¶

Three utilities are used to prepare the training data:

  1. normalize_data scales each feature to a fixed range (either [0, 1] or [-1, 1]). Recurrent models (including LSTMs) are sensitive to feature scale because their gating mechanisms rely heavily on tanh/sigmoid activations; unscaled inputs can lead to saturated activations and unstable gradients. In practice, [-1, 1] tends to work slightly better here.

  2. create_intervals builds a sliding-window dataset. With seq_length=366 and predict_days=1, it returns:

  • an input sequence of length 365 (seq_length - predict_days)
  • a target consisting of the next time step (1 step)
  3. StockDataset is a lightweight torch.utils.data.Dataset wrapper that converts the NumPy arrays into tensors so they can be consumed by a PyTorch DataLoader.
In [15]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler


def normalize_data(df):
    """Scale every feature column to [-1, 1].

    Returns the scaled array together with the fitted scaler so that
    predictions can later be mapped back via scaler.inverse_transform.
    """
    scaler = MinMaxScaler(feature_range=(-1, 1))
    return scaler.fit_transform(df), scaler


def create_intervals(data, seq_length=366, predict_days=1):
    """Build a sliding-window dataset from a 2-D array of time steps.

    Each window of `seq_length` consecutive rows is split into an input
    part (the first seq_length - predict_days rows) and a target part
    (the final predict_days rows).

    Returns a pair of numpy arrays with shapes
    (num_windows, seq_length - predict_days, n_features) and
    (num_windows, predict_days, n_features).
    """
    split = seq_length - predict_days
    windows = [data[start:start + seq_length]
               for start in range(len(data) - seq_length)]
    sequences = [window[:split] for window in windows]
    targets = [window[split:] for window in windows]
    return np.array(sequences), np.array(targets)

# Custom Dataset class
class StockDataset(Dataset):
    """Wraps (sequences, targets) arrays as float32 tensor pairs for a DataLoader."""

    def __init__(self, sequences, targets):
        self.sequences = sequences
        self.targets = targets

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        seq = torch.tensor(self.sequences[index], dtype=torch.float32)
        tgt = torch.tensor(self.targets[index], dtype=torch.float32)
        return seq, tgt
In [16]:
# Normalize every feature to [-1, 1]
scaled_data, scaler = normalize_data(df.values)

# Build sliding windows (default seq_length=366: 365 input steps + 1 target step)
sequences, targets = create_intervals(scaled_data)
In [17]:
#  Train-Test Split: chronological 80/20, no shuffling before the split
train_size = int(len(sequences) * 0.8) #outputs a single int 
test_size = len(sequences) - train_size
# Training portion: each 365-step input sequence is paired with the next single step
# NOTE(review): consecutive sliding windows overlap, so sequences near the split
# boundary share data between train and test — confirm this leakage is acceptable.
train_sequences = sequences[:train_size]
train_targets = targets[:train_size]
# Test portion: same pairing of 365-step inputs with one-step targets
test_sequences = sequences[train_size:]
test_targets = targets[train_size:]
In [18]:
# Create DataLoaders (batch_size=128; shuffle only the training windows)
train_dataset = StockDataset(train_sequences, train_targets)
test_dataset = StockDataset(test_sequences, test_targets)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

Each batch produced by the DataLoaders contains:

  • inputs with shape [128, 365, 6] = [batch, time, features]
  • targets with shape [128, 1, 6] = [batch, horizon, features]
In [19]:
# Sanity check: print the (input, target) shape of every batch
for x,i in train_loader:
    print (x.shape, i.shape)
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([128, 365, 6]) torch.Size([128, 1, 6])
torch.Size([100, 365, 6]) torch.Size([100, 1, 6])

Define the model¶

The model consists of stacked LSTM layers followed by a small feed-forward head:

  1. LSTM backbone. In the forward pass, the hidden state and cell state are initialized and passed together with the input sequence. From the LSTM outputs, we use the final hidden state of the top LSTM layer (h_n[-1]), which summarizes the entire input window.

  2. MLP prediction head. The final hidden state is mapped to the 6-dimensional output (next-step $x, y, z, \dot{x}, \dot{y}, \dot{z}$) via several fully connected layers. Empirically, a shallow head underfits, while adding many more layers provides limited benefit.

  3. Dropout and activations. Dropout is used for regularization and to improve stability during training. The nonlinearity is SELU, which worked comparably to tanh in this setting; ReLU tended to produce overly piecewise predictions.

In [20]:
# Sanity check: shape of the last time step's first three features per batch
for input, targets in train_loader:
    print (input[:, -1, :3].shape)
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([128, 3])
torch.Size([100, 3])
In [21]:
"""import torch
import torch.nn as nn

class stocklstm(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, seq_length=366, predict_days=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.seq_length = seq_length
        self.predict_days = predict_days * 6

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, batch_first=True, dropout=0.4) #lstm

        self.fc_1 =  nn.Linear(hidden_size, hidden_size*2) #fully connected 1
        self.fc_2 =  nn.Linear(hidden_size*2, hidden_size*4) #fully connected 1
        self.fc_3 =  nn.Linear(hidden_size*4, hidden_size*2) #fully connected 1
        self.fc_4 =  nn.Linear(hidden_size*2, hidden_size) #fully connected 1
        # Fully connected layers to map hidden state to output predictions
        self.fc_out = nn.Linear(hidden_size, self.predict_days)


        self.dropout= nn.Dropout(p=0.5)
        self.selu = nn.SELU()

    def forward(self,x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        output, (h_n, c_n) = self.lstm(x, (h0, c0))  # LSTM outputs all hidden states
        out = h_n[-1]  # Take the output of the last time step,  the last h_n of last layer, size [128,128]


        out = self.selu(self.fc_1(out))  # Activation after first FC
        out = self.dropout(out)

        out = self.selu(self.fc_2(out))  # Activation after first FC
        out = self.dropout(out)

        out = self.selu(self.fc_3(out))  # Activation after first FC
        out = self.dropout(out)
        
        out = self.selu(self.fc_4(out))  # Activation after first FC
        out = self.dropout(out)

        out = self.fc_out(out)  # Final output layer without activation (for regression)
        #print (out.shape)
        #out = out.reshape(-1,self.predict_days / 3,3)
        return out


"""
Out[21]:
'import torch\nimport torch.nn as nn\n\nclass stocklstm(nn.Module):\n    def __init__(self, input_size, hidden_size, num_layers, seq_length=366, predict_days=1):\n        super().__init__()\n        self.num_layers = num_layers\n        self.hidden_size = hidden_size\n        self.seq_length = seq_length\n        self.predict_days = predict_days * 6\n\n        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,\n                          num_layers=num_layers, batch_first=True, dropout=0.4) #lstm\n\n        self.fc_1 =  nn.Linear(hidden_size, hidden_size*2) #fully connected 1\n        self.fc_2 =  nn.Linear(hidden_size*2, hidden_size*4) #fully connected 1\n        self.fc_3 =  nn.Linear(hidden_size*4, hidden_size*2) #fully connected 1\n        self.fc_4 =  nn.Linear(hidden_size*2, hidden_size) #fully connected 1\n        # Fully connected layers to map hidden state to output predictions\n        self.fc_out = nn.Linear(hidden_size, self.predict_days)\n\n\n        self.dropout= nn.Dropout(p=0.5)\n        self.selu = nn.SELU()\n\n    def forward(self,x):\n        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)\n        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)\n\n        output, (h_n, c_n) = self.lstm(x, (h0, c0))  # LSTM outputs all hidden states\n        out = h_n[-1]  # Take the output of the last time step,  the last h_n of last layer, size [128,128]\n\n\n        out = self.selu(self.fc_1(out))  # Activation after first FC\n        out = self.dropout(out)\n\n        out = self.selu(self.fc_2(out))  # Activation after first FC\n        out = self.dropout(out)\n\n        out = self.selu(self.fc_3(out))  # Activation after first FC\n        out = self.dropout(out)\n        \n        out = self.selu(self.fc_4(out))  # Activation after first FC\n        out = self.dropout(out)\n\n        out = self.fc_out(out)  # Final output layer without activation (for regression)\n        #print 
(out.shape)\n        #out = out.reshape(-1,self.predict_days / 3,3)\n        return out\n\n\n'
In [55]:
import torch
import torch.nn as nn

class lorentzlstm(nn.Module):
    """Stacked LSTM with an MLP head for next-step Lorenz prediction.

    The LSTM consumes a window of shape [batch, time, input_size]; the
    final hidden state of the top layer summarizes the window and is
    mapped through four fully connected layers to the prediction.

    Args:
        input_size: number of features per time step.
        hidden_size: LSTM hidden dimension.
        num_layers: number of stacked LSTM layers.
        seq_length: window length (stored, not used in forward).
        predict_days: number of predicted steps; the output dimension is
            predict_days * 6 (x, y, z and their time derivatives per step).
    """

    def __init__(self, input_size, hidden_size, num_layers, seq_length=366, predict_days=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.seq_length = seq_length
        # 6 output values per predicted step: x, y, z, dx/dt, dy/dt, dz/dt.
        # NOTE(review): hard-codes 6 features; assumes input_size == 6 upstream.
        self.predict_days = predict_days * 6

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True, dropout=0.4)

        # MLP head: expand then contract back to hidden_size
        self.fc_1 = nn.Linear(hidden_size, hidden_size * 2)
        self.fc_2 = nn.Linear(hidden_size * 2, hidden_size * 4)
        self.fc_3 = nn.Linear(hidden_size * 4, hidden_size * 2)
        self.fc_4 = nn.Linear(hidden_size * 2, hidden_size)
        # Final projection from hidden state to the prediction vector
        self.fc_out = nn.Linear(hidden_size, self.predict_days)

        self.dropout = nn.Dropout(p=0.5)
        self.selu = nn.SELU()

    def forward(self, x):
        """Map a window [batch, time, input_size] to [batch, predict_days * 6]."""
        # Fresh zero states for every batch, allocated directly on x's device
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        _, (h_n, _) = self.lstm(x, (h0, c0))
        out = h_n[-1]  # final hidden state of the top layer: [batch, hidden_size]

        # Each head layer applies dropout to the linear output, then SELU
        out = self.selu(self.dropout(self.fc_1(out)))
        out = self.selu(self.dropout(self.fc_2(out)))
        out = self.selu(self.dropout(self.fc_3(out)))
        out = self.selu(self.dropout(self.fc_4(out)))

        # No activation on the final layer (regression output)
        return self.fc_out(out)

Training configuration notes:

  1. Model size. The model is instantiated with input_size=6, hidden_size=128, and num_layers=9. While many references suggest that 1–3 LSTM layers are often sufficient, increasing depth improved performance for this specific task (likely because the dynamics are highly nonlinear). Further improvements may also come from alternatives such as larger hidden size, layer normalization, or architectural changes rather than adding more layers.

  2. Loss function. Both MSELoss and HuberLoss were tested. Since this dataset is deterministic and essentially noise-free, MSELoss is a natural choice; in practice the two behaved similarly.

  3. Optimizer. Several optimizers were evaluated (AdamW, Adam, Adagrad, SGD, RMSprop). RMSprop performed best in this notebook with lr=1e-4 and momentum=0.9 .

In [56]:
device = "cuda"
lstm = lorentzlstm(6,128,9).to("cuda")
criterion = nn.MSELoss()
optimizer = torch.optim.RMSprop(lstm.parameters(), lr=0.0001, momentum=0.9)
In [24]:
lstm
Out[24]:
lorenzlstm(
  (lstm): LSTM(6, 128, num_layers=9, batch_first=True, dropout=0.4)
  (fc_1): Linear(in_features=128, out_features=256, bias=True)
  (fc_2): Linear(in_features=256, out_features=512, bias=True)
  (fc_3): Linear(in_features=512, out_features=256, bias=True)
  (fc_4): Linear(in_features=256, out_features=128, bias=True)
  (fc_out): Linear(in_features=128, out_features=6, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (selu): SELU()
)
In [54]:
# Smoke test: run a single forward pass on one training batch
for input, targets in train_loader:
    input= input.to(device)
    print(lstm(input))
    break
tensor([[ 4.5706e-01,  5.9219e-01, -3.2413e-01,  7.2902e-02, -3.9372e-01,
          1.0854e+01],
        [-2.3636e-01, -2.9977e-01, -7.2791e-01, -1.3338e+01, -1.5106e+01,
          1.2457e+01],
        [-6.9575e-01, -3.4633e-01,  3.1442e-01, -8.4478e+00, -2.6035e+00,
         -1.3429e-02],
        [-2.6344e-01, -2.0738e-01, -1.5209e-01, -3.0899e+00, -1.3704e+00,
          5.3749e+00],
        [ 5.9867e-02,  1.0903e-01, -5.4822e-01, -4.9003e+00, -3.1270e+00,
          2.1530e+00],
        [-4.2312e-01, -2.9975e-01, -5.9720e-04,  4.4988e+00,  6.7966e+00,
         -9.9345e-01],
        [ 1.1748e-01,  9.9251e-02, -2.4118e-01, -3.7457e+00,  1.6559e-01,
         -2.4794e+00],
        [-5.6432e-01, -5.6410e-01, -1.0804e-03, -3.1674e+00, -6.1743e+00,
          4.4929e+00],
        [ 1.3977e-01,  1.7970e-01, -7.3551e-01, -5.9584e+00, -9.7566e+00,
         -8.3475e+00],
        [ 2.7093e-01,  2.0480e-01, -1.2517e-01, -4.6318e+00, -6.9207e+00,
          3.9184e+00],
        [ 5.8506e-02,  2.4120e-02, -2.1770e-01, -7.7911e+00, -8.4149e+00,
          7.6862e+00],
        [-5.5449e-01, -3.5391e-01,  5.2127e-02, -1.6009e+01, -8.7792e+00,
          4.6942e+00],
        [-1.3223e-01,  1.1460e-02, -5.3280e-03,  4.7991e+00, -2.1139e+00,
         -4.8492e+00],
        [ 4.4891e-01,  3.8311e-01,  2.4625e-02, -2.5773e+00, -7.6111e+00,
          1.2039e+01],
        [ 4.3122e-02,  1.3942e-01, -5.2556e-01, -1.7972e+01, -1.6105e+01,
          7.3489e+00],
        [-1.8516e-01, -2.4053e-01, -4.9439e-01, -9.6922e+00, -1.3543e+01,
         -1.3873e+00],
        [ 2.1447e-01,  2.0735e-01, -1.1855e-01, -1.0805e+01, -6.9515e+00,
          3.8156e+00],
        [-4.0323e-01, -3.0041e-01,  1.3249e-02,  5.4121e+00,  6.5406e+00,
          2.2687e+00],
        [ 2.0345e-01, -3.1579e-02,  1.4339e-01, -8.9741e+00,  3.6023e+00,
         -6.9705e+00],
        [-8.8486e-01, -8.1751e-01,  4.0597e-02, -1.3695e+01, -3.0540e+00,
         -1.6027e+00],
        [-3.3790e-01,  8.0693e-02,  2.6635e-01, -1.2759e+01, -6.8783e+00,
          8.0309e+00],
        [-2.7190e-02,  4.2646e-02, -4.5539e-01, -4.1125e+00,  4.2977e+00,
         -3.6577e+00],
        [ 5.3388e-01,  2.4911e-01,  1.8353e-01,  1.7844e+01,  5.7405e+00,
          1.0346e+01],
        [-2.5825e-01, -1.7355e-01,  6.3927e-02,  1.3895e+01,  3.5734e+00,
         -1.5612e+00],
        [-4.8680e-01, -3.4559e-01,  1.6527e-02, -1.2544e+00, -6.8753e-01,
         -4.2065e+00],
        [ 2.1386e-01, -9.6538e-02,  2.1728e-01, -4.8549e+00, -4.2956e+00,
          5.6135e+00],
        [ 2.0297e-01,  1.0255e-01, -4.5017e-02, -1.6142e+01, -1.7096e+01,
          1.2138e+00],
        [ 2.3929e-01,  3.5316e-01, -3.7058e-01, -6.7016e+00,  8.8146e-02,
          1.7038e-01],
        [ 2.1552e-01,  3.0268e-01, -5.6545e-01, -3.7417e+00, -2.6029e+00,
         -2.9043e+00],
        [ 6.6325e-01,  3.2544e-01,  1.7666e-01,  2.1654e+01,  1.6419e+01,
         -4.6540e+00],
        [ 3.7139e-01, -1.8916e-02,  2.8381e-01, -7.8105e+00,  1.7753e+01,
         -2.1573e+01],
        [ 2.3600e-01,  2.3274e-01, -1.6113e-01, -1.1591e+01, -8.1705e+00,
         -8.8523e-01],
        [ 5.7741e-01,  6.3892e-01, -2.1348e-01,  1.8281e+01,  1.7859e+01,
          1.1580e+01],
        [ 4.7881e-01,  1.4762e-01,  2.0431e-01,  8.5816e+00,  3.3745e+00,
          1.8798e+00],
        [-7.3636e-01, -6.8643e-01, -7.0334e-02, -2.2276e+01, -1.3740e+01,
          9.7507e+00],
        [-2.4289e-01, -2.0481e-01, -1.8683e-01, -8.0444e-01,  7.8556e-01,
          5.4584e+00],
        [-3.6176e-01, -2.1137e-01,  2.2190e-02,  1.6457e+00,  3.7715e+00,
          2.3722e+00],
        [-7.2457e-01, -4.1920e-01,  3.1168e-01, -1.1235e+01, -5.4377e+00,
          3.7685e+00],
        [-5.3189e-01, -6.1972e-01, -3.6150e-01,  9.4366e-02,  2.5083e+00,
         -5.6360e+00],
        [-4.3836e-02, -8.2365e-02, -1.9024e-01, -7.3975e+00, -2.1166e+00,
         -2.4645e+00],
        [-6.1709e-01, -4.9126e-01,  3.9841e-01,  2.0277e+00,  2.9145e+00,
          2.4451e+01],
        [-4.3709e-02,  4.5394e-02, -4.1511e-02, -1.3456e+00, -3.8715e+00,
          3.4252e+00],
        [ 2.1895e-01, -2.6684e-02,  1.6973e-01, -2.3671e+01, -8.3290e+00,
         -1.5996e+01],
        [ 4.6162e-01,  3.2359e-01, -3.7117e-03,  1.8961e+00,  6.2198e-01,
         -4.7307e+00],
        [-1.7750e-01, -2.1658e-01, -4.1395e-01, -2.2070e+00, -3.9322e+00,
          1.2013e+00],
        [-6.5496e-01, -4.9381e-01,  3.1155e-01,  5.3702e+00,  1.3069e+01,
          1.3449e+01],
        [ 8.5956e-01,  6.9507e-01,  1.9602e-01,  1.7075e+01,  8.5904e+00,
          4.1940e-01],
        [ 2.0063e-01,  1.9193e-01, -2.4913e-01,  2.2939e+00,  5.1831e-01,
          1.5525e-01],
        [-4.5837e-01, -3.6277e-01, -5.0755e-03,  3.1728e+00, -2.5903e+00,
         -9.4840e+00],
        [-8.7084e-01, -3.8919e-01,  6.6428e-01, -1.3449e+01,  1.2413e+01,
          3.1175e+01],
        [-3.4278e-01, -3.4414e-01, -1.8231e-01,  2.5199e+00,  3.6291e+00,
          6.0332e+00],
        [-3.5112e-01, -3.2888e-01, -5.5591e-02,  1.1614e+01,  7.7847e+00,
         -7.9130e-01],
        [-3.1043e-01, -2.7639e-01, -5.4530e-02, -7.0320e+00,  3.8980e+00,
          1.5576e-01],
        [-6.3635e-01, -3.2490e-01,  1.0898e-01, -4.0231e+00,  3.6474e+00,
         -1.4063e+01],
        [-3.3592e-01, -3.2114e-01, -6.6361e-02,  1.1882e+01,  3.9980e+00,
         -5.8212e+00],
        [-8.8912e-02, -1.1174e-01, -7.1589e-01, -5.0879e+00, -6.2174e+00,
         -1.0604e+01],
        [-3.8817e-01, -3.3325e-01, -3.3812e-02,  9.0600e+00,  9.6577e+00,
          1.3311e+00],
        [-4.5328e-01, -1.6877e-01,  2.7180e-01,  1.1551e+01,  9.1798e-01,
         -9.8503e+00],
        [ 4.9234e-01,  3.2229e-01,  1.0149e-01, -8.7148e+00, -2.4012e+00,
         -1.6358e+01],
        [-3.4939e-01, -3.6913e-01, -1.9450e-01, -3.5491e+00, -2.7204e+00,
          7.0649e+00],
        [ 3.5069e-01,  3.4882e-01, -7.3911e-02, -9.0740e+00, -3.4802e+00,
         -1.8687e+00],
        [-6.9478e-02, -2.1910e-01,  1.3125e-02, -1.1725e+01, -3.9037e+00,
         -1.9804e+00],
        [ 5.5236e-01,  4.6620e-01,  3.5994e-02,  1.4394e+01,  8.6059e+00,
          1.5431e+01],
        [-5.9842e-01, -5.5735e-01, -1.1789e-01, -1.5240e+01, -1.4716e+01,
         -1.7677e+00],
        [ 2.5562e-01,  2.3248e-01, -1.6493e-01, -2.8526e+00, -5.0240e+00,
          1.0697e+01],
        [-2.8930e-01, -1.7796e-01, -6.2743e-02,  5.6372e+00, -3.0559e-01,
         -9.9510e+00],
        [-3.8307e-01, -5.8047e-01, -6.0355e-01,  1.0151e+01,  7.2144e+00,
         -9.3873e+00],
        [-3.5098e-01, -3.1085e-01, -1.2332e-01,  5.6004e+00,  7.1535e-01,
         -8.0995e+00],
        [ 5.9105e-01,  3.5145e-01,  2.7617e-01,  4.9226e+00,  4.5452e-01,
          8.0378e+00],
        [-5.0977e-01, -3.4236e-01,  1.2570e-01,  2.1400e+00,  2.2076e-01,
         -2.8986e+00],
        [-5.1198e-01, -3.7433e-01,  1.3594e-01,  7.3759e+00,  3.5824e+00,
         -4.4092e+00],
        [ 2.3821e-01,  1.0805e-01, -5.8209e-04, -2.5514e+00, -8.4299e-02,
          1.6109e+00],
        [ 3.6005e-01,  3.0778e-01,  3.4564e-02, -1.5615e+01, -7.9025e+00,
         -6.4659e+00],
        [-5.9287e-01, -4.5682e-01,  1.7896e-01, -2.7328e+00, -5.1657e+00,
          2.9717e+00],
        [-2.9207e-02, -1.9712e-02, -3.3435e-01, -6.1084e+00, -5.8978e+00,
          7.2281e+00],
        [ 2.3829e-02, -1.6774e-01, -4.4813e-02,  4.2294e-01, -2.3801e-01,
         -2.8119e+00],
        [ 2.0444e-01,  6.6389e-02, -6.1231e-03, -9.2312e+00, -4.5491e+00,
         -4.1395e+00],
        [ 4.7536e-01,  4.6242e-01,  4.3502e-02, -3.2223e+00, -7.6268e+00,
          1.8817e+01],
        [ 2.5064e-01,  1.7238e-01, -3.3934e-02, -2.3698e+00, -3.3705e-01,
          5.9549e+00],
        [-3.3157e-01, -3.2817e-01, -1.7639e-01, -6.2711e+00, -4.4571e+00,
          1.3929e+01],
        [-1.6412e-01, -2.4306e-01, -3.9121e-01,  1.2604e+00, -4.1189e+00,
          2.9443e+00],
        [ 2.4618e-01,  9.8893e-02, -4.3148e-03, -1.0886e+01, -7.3593e+00,
         -6.4842e+00],
        [ 2.3558e-02,  6.4761e-02, -8.0237e-02, -1.0254e+00, -3.3013e+00,
          6.5812e+00],
        [-1.8092e-01, -1.9998e-01, -2.9721e-01,  5.3206e-01,  2.3636e+00,
          9.4154e+00],
        [-1.3522e-01, -1.9182e-01, -2.0662e-01, -6.3542e+00, -5.2282e+00,
         -3.3052e-01],
        [ 5.2101e-01,  5.2484e-01, -2.3666e-01, -2.4632e+00, -8.4696e+00,
         -7.3449e+00],
        [-5.4964e-01, -4.7247e-01,  1.6707e-02,  9.1840e-01, -1.4307e+00,
         -5.4837e+00],
        [-1.1385e+00, -1.0748e+00,  3.6275e-02, -4.0346e+01, -2.6390e+01,
          4.7576e+00],
        [-5.4056e-01, -3.8437e-01,  4.4346e-02, -1.0760e-01,  9.5299e-01,
         -6.9680e+00],
        [-1.9074e-01, -1.7497e-02,  9.4114e-02, -5.0622e+00, -6.2583e+00,
          8.5333e+00],
        [-1.1209e-01, -1.0318e-01, -4.6016e-01, -4.0158e-01,  2.4355e+00,
          8.1435e-01],
        [ 4.4195e-01,  4.9983e-02,  2.2555e-01, -1.0621e+00, -8.6646e+00,
         -2.9927e+00],
        [-7.3296e-01, -3.7874e-01,  3.5006e-01,  8.7810e-01, -7.9118e+00,
         -1.9583e+01],
        [-3.6962e-01, -2.4493e-01, -4.8879e-02,  1.4432e+00, -1.1767e+00,
         -7.7979e+00],
        [-5.4574e-01, -6.6280e-01, -1.5804e-01,  3.4485e+00,  2.2084e+00,
          7.7131e+00],
        [-1.4294e-01, -2.8398e-01,  5.6203e-03, -5.7029e+00, -1.6421e+00,
          2.4155e-01],
        [ 3.8756e-01,  2.9761e-01,  2.7136e-02, -7.9143e+00, -2.2213e+00,
         -5.9150e+00],
        [-4.3166e-01, -3.3144e-01,  1.4778e-01,  6.9083e-01, -4.5419e+00,
          8.8889e+00],
        [ 7.4095e-02,  9.6721e-02, -9.0164e-01,  2.1012e-01, -1.0984e+00,
         -1.4211e+01],
        [-5.1086e-02,  2.1163e-01,  1.6787e-01, -6.8572e+00, -1.4064e+00,
          1.0911e+01],
        [ 3.0154e-01,  1.6241e-01,  1.5174e-02,  1.6737e+00,  5.2614e+00,
         -2.8253e-01],
        [ 8.2974e-02, -2.7078e-01,  1.6944e-01, -1.6476e+01,  2.2930e+00,
         -1.6186e+01],
        [-3.5574e-01, -1.9818e-01,  2.0885e-01,  1.8032e+01,  4.1685e+00,
         -5.8657e+00],
        [ 3.2363e-01,  4.0125e-01, -2.1830e-01, -6.6348e-01,  4.9859e+00,
          5.5418e+00],
        [ 4.6960e-01,  2.2421e-01,  1.6993e-01,  1.5223e+00,  5.2906e+00,
         -5.5762e+00],
        [-4.3706e-01, -2.8690e-01, -4.9908e-03, -5.5123e+00, -5.8009e+00,
         -3.5803e+00],
        [-4.1920e-02, -3.4013e-02, -7.4290e-01, -3.9132e+00, -3.2754e+00,
         -2.6803e+00],
        [-5.9298e-01, -3.5351e-01,  2.8181e-01,  2.2947e+01,  3.4818e+01,
         -4.9614e+00],
        [-3.1012e-01, -3.1927e-01, -4.8845e-01, -8.9364e+00, -4.7594e+00,
         -6.0642e-02],
        [-3.6832e-01,  1.9798e-01,  3.2501e-01,  8.9770e+00,  5.2406e+00,
         -1.2633e+01],
        [ 2.0473e-01, -2.9946e-02,  1.9027e-01, -2.1353e-01,  7.0200e+00,
          5.5199e+00],
        [ 3.7670e-01,  3.8886e-01,  8.7231e-03, -1.2324e+01, -7.5578e+00,
          4.9178e+00],
        [-2.9358e-01, -3.0313e-01, -5.0451e-01, -1.2959e+01, -9.9090e+00,
         -3.2896e+00],
        [ 4.8542e-01,  5.9339e-01, -2.8306e-01, -1.4257e+00,  3.4923e-01,
         -1.8165e+00],
        [-5.1524e-02, -1.9421e-02, -5.1417e-01, -1.1775e+00,  2.4988e+00,
         -6.5986e+00],
        [-6.4935e-02,  1.1519e-01,  1.1128e-01,  1.0456e+00, -7.8174e+00,
          1.1358e+00],
        [-4.6985e-01, -3.3564e-01,  3.8050e-02,  1.1485e+01,  6.6417e+00,
         -1.4896e+01],
        [ 1.5784e-02,  1.8734e-03, -4.6540e-01, -6.4880e+00, -9.3138e+00,
         -7.9960e-01],
        [-3.7958e-01, -2.8853e-01, -2.0593e-02,  3.2070e+00,  1.4624e+00,
         -8.2976e-01],
        [-3.4538e-01, -2.5141e-01, -6.3491e-02,  6.5816e+00,  4.1525e+00,
         -6.4046e+00],
        [-2.9733e-01, -3.7599e-01, -3.1068e-01,  3.1308e+00, -1.9652e+00,
         -3.8378e-01],
        [ 4.6185e-01,  1.1708e-01,  3.2040e-01, -6.4888e+00,  3.6486e+00,
         -9.4234e+00],
        [-1.9995e-01, -2.2326e-01, -2.0695e-01, -7.7680e-01, -3.4087e+00,
          1.0273e+01],
        [-4.2778e-01, -3.1557e-01,  4.4745e-02,  1.8006e+00,  2.1934e+00,
          3.8276e+00],
        [-4.0893e-01, -3.1460e-01, -7.9684e-03,  3.5355e+00,  3.9236e+00,
          1.0755e+00],
        [ 2.0062e-01,  2.0944e-01, -1.6319e-01, -1.0380e+01, -1.4709e+00,
         -5.5313e+00],
        [ 2.1377e-01, -4.9791e-02,  1.5049e-01, -1.4535e+01, -2.4302e+00,
         -1.1280e+01],
        [-2.2736e-01, -2.2473e-01, -9.3776e-02,  1.2261e+01,  8.2166e+00,
          5.1882e+00]], device='cuda:0', grad_fn=<CatBackward0>)
In [26]:
from torchviz import make_dot

# Build a dummy batch shaped like the training data: (batch, seq_len, features).
dummy_input = torch.randn(1, 365, 6).to(device)
prediction = lstm(dummy_input)

# Trace the forward pass and render the computation graph to an SVG file.
graph = make_dot(prediction, params=dict(lstm.named_parameters()))
graph.render("lorentz_attractor_model_architecture", format="svg")
Out[26]:
'lorentz_attractor_model_architecture.svg'

Train the model¶

In [58]:
# Training loop: MSE on the next-step prediction, with per-element gradient
# value clipping to tame exploding gradients (common for RNNs on chaotic data).
clip_value = 0.8
num_epochs = 20
for epoch in range(num_epochs):
    lstm.train()
    train_loss = 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        outputs = lstm(inputs)
        # squeeze(1) removes only the length-1 time dimension of the targets
        # ((batch, 1, features) -> (batch, features), consistent with
        # test_targets indexed as [i, 0, j] elsewhere in this notebook).
        # A bare .squeeze() would also drop the batch dimension whenever the
        # last batch happens to have size 1, silently misaligning the loss.
        loss = criterion(outputs, targets.squeeze(1))

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()

        # Clamp every gradient element to [-clip_value, clip_value].
        nn.utils.clip_grad_value_(lstm.parameters(), clip_value)
        optimizer.step()

        train_loss += loss.item()

    # Average train loss for this epoch
    train_loss /= len(train_loader)

    # Validation step (on the held-out test set)
    lstm.eval()
    test_loss = 0
    with torch.inference_mode():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = lstm(inputs)
            loss = criterion(outputs, targets.squeeze(1))
            test_loss += loss.item()

    # Average test loss for this epoch
    test_loss /= len(test_loader)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
Epoch [1/20], Train Loss: 33.2310, Test Loss: 5.9038
Epoch [2/20], Train Loss: 30.7409, Test Loss: 1.3986
Epoch [3/20], Train Loss: 29.4363, Test Loss: 2.3024
Epoch [4/20], Train Loss: 28.6587, Test Loss: 0.8585
Epoch [5/20], Train Loss: 28.5013, Test Loss: 0.9563
Epoch [6/20], Train Loss: 28.9173, Test Loss: 1.8390
Epoch [7/20], Train Loss: 28.6386, Test Loss: 1.7320
Epoch [8/20], Train Loss: 27.7666, Test Loss: 3.1902
Epoch [9/20], Train Loss: 28.1951, Test Loss: 3.6165
Epoch [10/20], Train Loss: 28.2287, Test Loss: 1.5740
Epoch [11/20], Train Loss: 27.6656, Test Loss: 1.2109
Epoch [12/20], Train Loss: 27.6725, Test Loss: 1.8751
Epoch [13/20], Train Loss: 27.4233, Test Loss: 1.7084
Epoch [14/20], Train Loss: 26.8021, Test Loss: 1.8940
Epoch [15/20], Train Loss: 26.6064, Test Loss: 2.1320
Epoch [16/20], Train Loss: 26.6971, Test Loss: 3.1355
Epoch [17/20], Train Loss: 26.4950, Test Loss: 2.1259
Epoch [18/20], Train Loss: 26.3924, Test Loss: 1.9396
Epoch [19/20], Train Loss: 26.5422, Test Loss: 2.8664
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[58], line 22
     19     nn.utils.clip_grad_value_(lstm.parameters(), clip_value)
     20     optimizer.step()
---> 22     train_loss += loss.item()
     24 # Average train loss for this epoch
     25 train_loss /= len(train_loader)

KeyboardInterrupt: 
In [ ]:
# NOTE(review): dead experiment kept for reference, deliberately wrapped in a
# string literal so it never executes.  It is a physics-informed variant of the
# training loop above: in addition to the position loss on (x, y, z), it adds
# central-difference estimates of the predicted derivatives and penalizes their
# mismatch with the true derivatives stored in target columns 3-5.
"""# Training loop
clip_value = 0.5
num_epochs = 20
for epoch in range(num_epochs):
    lstm.train()
    train_loss = 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        # Forward pass
        output = lstm(inputs)

        dx_dt = (output[2:, 0] - output[:-2, 0]) / (2 * dt)
        dy_dt = (output[2:, 1] - output[:-2, 1]) / (2 * dt)
        dz_dt = (output[2:, 2] - output[:-2, 2]) / (2 * dt)

        true_dx_dt = targets.squeeze()[1:-1,3]
        true_dy_dt = targets.squeeze()[1:-1,4]
        true_dz_dt = targets.squeeze()[1:-1,5]


        loss_positions = criterion(output, targets.squeeze()[:,:3])  # Loss between predicted and true x, y, z
        loss_dx = criterion(dx_dt, true_dx_dt)/3
        loss_dy = criterion(dy_dt, true_dy_dt)/3
        loss_dz = criterion(dz_dt, true_dz_dt)/3
        
        # Total loss: combining position and derivative losses
        loss = loss_positions + loss_dx + loss_dy + loss_dz

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()

        nn.utils.clip_grad_value_(lstm.parameters(), clip_value)
        optimizer.step()
        
        train_loss += loss.item()
    
    # Average train loss for this epoch
    train_loss /= len(train_loader)
    


    # Validation step (on test set)
    lstm.eval()
    test_loss = 0
    with torch.inference_mode():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            output =lstm(inputs)
            loss = criterion(output, targets.squeeze()[:,:3])
            test_loss += loss.item()
    
    # Average test loss for this epoch
    test_loss /= len(test_loader)
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
"""

Visualize results¶

The final section visualizes an autoregressive rollout on a selected test sequence.

  • plot_index selects which test sequence to visualize.
  • predict_amount controls how many future time steps are generated.
In [58]:
# Which test sequence to visualize, and how many future steps to roll out.
plot_index = 200
predict_amount = 365
In [59]:
# Take one test sequence and add a leading batch dimension: (1, seq_len, features).
seed_sequence = test_sequences[plot_index]
A = torch.tensor(seed_sequence, dtype=torch.float32).unsqueeze(0).to(device)
A.shape
Out[59]:
torch.Size([1, 365, 6])

Autoregressive inference is performed by repeatedly predicting the next step, appending it to the current sequence, and predicting again. As the horizon increases, small errors compound, so long rollouts are expected to become progressively less accurate.

In [60]:
# Autoregressive rollout: predict the next step, append it to the running
# sequence, and feed the grown sequence back in — predict_amount times.
B = A
lstm.eval()
with torch.inference_mode():
    for _ in range(predict_amount):
        # lstm(B) yields one new time step; unsqueeze restores the batch axis
        # so it can be concatenated along the time dimension (dim 1).
        next_step = lstm(B).unsqueeze(0)
        B = torch.concat((B, next_step), 1)
print(B.shape)
# Move both tensors to the CPU as NumPy arrays for plotting.
A = A.cpu().numpy()
B = B.cpu().numpy()
torch.Size([1, 730, 6])
In [61]:
# Sanity check: (num_sequences, seq_len, features) — (3578, 365, 6) per the output below.
test_sequences.shape
Out[61]:
(3578, 365, 6)
In [ ]:
 
In [62]:
# Sanity check: one ground-truth x value per rollout step — shape (predict_amount,).
test_targets[plot_index:plot_index + predict_amount, 0, 0].shape
Out[62]:
(365,)
In [63]:
# Red:   the seed input sequence fed to the model.
# Blue:  the autoregressive rollout (last predict_amount steps of B).
# Green: the ground-truth continuation, taken from the targets starting at plot_index.
ax = plt.axes(projection='3d')
# ax.view_init(0, 90)  # uncomment to change the viewing angle
seed = test_sequences[plot_index]
ax.plot3D(seed[:, 0], seed[:, 1], seed[:, 2], 'red')
rollout = B[0, -predict_amount:]
ax.plot3D(rollout[:, 0], rollout[:, 1], rollout[:, 2], 'blue')
truth = test_targets[plot_index:plot_index + predict_amount, 0]
ax.plot3D(truth[:, 0], truth[:, 1], truth[:, 2], 'green')
Out[63]:
[<mpl_toolkits.mplot3d.art3d.Line3D at 0x17d828dab10>]
No description has been provided for this image
In [57]:
# Optional checkpointing: uncomment to save or restore the trained weights.
# NOTE(review): torch.load on an untrusted file can execute arbitrary code;
# prefer torch.load(..., weights_only=True) on recent PyTorch versions.
#torch.save(lstm.state_dict(), "C:/Users/repea/Downloads/lorenz.pth")
#lstm.load_state_dict(torch.load("C:/Users/repea/Downloads/lorenz.pth"))
Out[57]:
<All keys matched successfully>