How I Trained a Neural Network to Simulate Mario Kart Gameplay
2024-11-18
Introduction
In this post, I’ll guide you through my journey of setting up a neural network (NN) capable of simulating Mario Kart gameplay, frame by frame, based on the previous frame and a steering input (left or right). This was achieved without embedding traditional game logic into the code. Let’s dive into the inspiration, technical challenges, and final implementation of this ambitious project.
Initial Inspiration
My journey began after seeing Google train an AI to play DOOM, showcasing how neural networks can excel in simulating game behavior. Although I was fascinated by the concept, my limited AI knowledge held me back.
Things changed when AI Minecraft went viral. Seeing that success rekindled my interest and convinced me this approach could work for any game. I began researching existing models and techniques.
Discovering the DIAMOND Model
I came across DIAMOND, a model trained on games like ATARI and CSGO, capable of simulating environments directly from gameplay. It struck me as the perfect foundation for adapting to a game I loved: Mario Kart. My goal was to:
- Prove feasibility by adapting DIAMOND to a different game.
- Generalize the workflow for training similar models on other games.
This blog post will walk you through my process of training the model, the challenges I faced, and the technical lessons learned. A future YouTube video will document this approach applied to another game, so stay tuned!
Game Selection: Why Not SM64?
I initially wanted to train the model on Super Mario 64 (SM64). The game is a cultural icon and has been decompiled, offering deep insights into its mechanics. Its small native resolution and extensive community understanding made it seem like a great candidate. However, I hit several roadblocks:
- No Existing Dataset: I couldn’t find datasets for SM64, and creating one manually was daunting.
- High Data Requirements: Even with my best efforts, collecting a comprehensive dataset covering SM64’s vast and varied levels would have taken weeks of manual gameplay.
- Complex Gameplay: SM64’s range of movements (e.g., jumping, complex camera controls) made it harder to generalize.
After initial tests and discussions with the SM64 community, I concluded the scope was too large for a solo project.
Dataset Requirements for Training
Understanding the dataset format is critical to training the NN. The DIAMOND model’s dataset uses HDF5 files, structured as follows:
- Frame Keys: Each key represents a frame and follows the pattern frame_<N>_x (image data) and frame_<N>_y (input data), where N is the frame number (0-999 per file).
- Image Data (_x): Each frame is stored as a 3D array with dimensions (240 x 320 x 3):
  - 240 x 320: Height and width.
  - 3: RGB color channels.
  (Note: Height and width often swap during tensor conversion. This is an issue you’ll need to debug when preprocessing.)
- Input Data (_y): A vector representing the player’s actions for that frame:
  - The vector’s size equals the number of input actions (e.g., 4 for WASD controls).
  - Example for WASD: [1, 0, 0, 1] means W and D are pressed simultaneously.
Here’s a simple example of the input vector structure:
# Input action example: WASD controls
# No keys pressed:
action_vector = [0, 0, 0, 0]
# Pressing W and D together:
action_vector = [1, 0, 0, 1]
The original DIAMOND dataset included additional data like memory states, which I ignored because this model only simulates gameplay, not game logic.
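To make the naming scheme and the action layout concrete, here is a minimal sketch. The helper names `frame_keys` and `encode_keys` are hypothetical (they are not part of DIAMOND), and the W, A, S, D ordering is an assumption inferred from the `[1, 0, 0, 1]` example above.

```python
# Hypothetical helpers illustrating the key naming and action layout described above.
WASD_ORDER = ("w", "a", "s", "d")  # assumed order, matching the [1, 0, 0, 1] example


def frame_keys(n):
    """Return the (image, action) dataset keys for frame n within one file."""
    if not 0 <= n <= 999:
        raise ValueError("frame number must be in 0-999 for a single file")
    return f"frame_{n}_x", f"frame_{n}_y"


def encode_keys(pressed, order=WASD_ORDER):
    """Build a multi-hot action vector: 1 where the key is pressed, else 0."""
    pressed = {k.lower() for k in pressed}
    return [1 if key in pressed else 0 for key in order]


print(frame_keys(0))            # ('frame_0_x', 'frame_0_y')
print(encode_keys(["W", "D"]))  # [1, 0, 0, 1]
```

Keeping the key order in one tuple makes it harder for the recorder and the playback code to disagree about which slot means which key.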
For more on the dataset format, check out the CSGO dataset guide.
Building the Dataset: SM64 Challenges
Since no pre-existing SM64 dataset was available, I built a simple program to record my inputs and gameplay frames. You can find the code on my GitHub. However, this approach posed significant challenges:
- Time-Intensive: Capturing even a small portion of the game’s map took hours of manual effort.
- Insufficient Data: Initial tests with just a few minutes of gameplay led to poor results. (Check out the preliminary output: YouTube link.)
This made it clear that training a robust NN for SM64 required either a team effort or a much larger dataset.
Pivoting to Mario Kart
After evaluating alternatives, I settled on Mario Kart 64 (MK64). Here’s why MK64 proved a better fit:
- Easier Dataset Collection: Existing AI models for MK64 could automate gameplay recording, significantly reducing manual effort.
- Simpler Controls: Limited steering actions (left/right) and no complex camera movements make it ideal for NN training.
- Low Resolution: The game’s graphics look good even when scaled down.
Although SM64 remains a viable option with proper resources, MK64 was the better choice for a solo project. As a side note, I briefly considered a community-driven approach to collect SM64 data via a plugin for the SM64coopEX community, but this was outside my project’s scope.
Dataset Collection for Mario Kart 64
After defining a clear plan and shifting my focus to Mario Kart 64 (MK64), I felt renewed confidence. This section covers how I set up the data collection process, including the challenges of adapting the NeuralKart emulator and generating a robust dataset.
How NeuralKart Interacts with the Emulator
The NeuralKart neural network plays MK64 on an emulator via a local TCP server. The emulator runs a Lua script to send and receive inputs from the server, enabling real-time communication. My task was to intercept and capture these interactions to build a dataset.
NeuralKart controls the emulator using a steering input ranging from -1.0 to 1.0:
- -1.0: Full left steer.
- 1.0: Full right steer.
- 0.0: No steering.
- Intermediate values represent partial steering.
This decimal steering input needed to be converted into a format compatible with the DIAMOND model. Specifically, I transformed the steering values into a one-hot vector representation.
Converting Decimal Steering to a One-Hot Vector
To convert the steering value into a vector of size 20 (representing 20 uniform intervals between -1.0 and 1.0), I implemented the following function:
def decimal_to_vector(decimal_number):
    # Initialize a list of 20 zeros
    vector = [0] * 20
    # Ensure the decimal_number is within the expected range
    decimal_number = max(-1.0, min(decimal_number, 1.0))
    # Calculate the index based on the input decimal number
    # Map -1.0 to 0, ..., 0.0 to 10, ..., 1.0 to 19
    index = round((decimal_number + 1.0) * 10)
    # Cap the index at 19 to handle the case where decimal_number is exactly 1.0
    index = min(index, 19)
    # Set the appropriate position in the vector to 1
    vector[index] = 1
    return vector
- This splits the steering interval into 20 discrete steps.
- Each vector has exactly one 1, indicating the corresponding steering range.
For example:
- A steering value of -1.0 maps to: [1, 0, 0, ..., 0].
- A steering value of 0.0 maps to: [0, ..., 1, ..., 0] (center position).
- A steering value of 1.0 maps to: [0, ..., 0, 1].
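To see where individual steering values land, the sketch below reuses the same arithmetic as `decimal_to_vector` but returns only the index. The name `steering_index` is a hypothetical helper introduced for illustration. Because the index is capped at 19, steering values near 0.9 and the full-right value 1.0 end up in the same last slot.

```python
def steering_index(value, n_bins=20):
    """Index that decimal_to_vector sets to 1 (same arithmetic, index only)."""
    value = max(-1.0, min(value, 1.0))               # clamp to the valid range
    return min(round((value + 1.0) * 10), n_bins - 1)  # cap so 1.0 fits in slot 19


for value in (-1.0, -0.5, 0.0, 0.5, 1.0):
    print(f"{value:+.1f} -> index {steering_index(value)}")
```

Printing a handful of values like this is a cheap way to confirm that the recorder and the playback code agree on the binning.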
Capturing and Storing Gameplay Data
To capture data, I modified the NeuralKart program to log the steering inputs and corresponding gameplay frames. Below is the main function I used to save data in HDF5 format, compatible with the DIAMOND model:
import cv2
import h5py

def capture_frame(prediction, img):
    global t
    global file_index
    global h5file
    # Create a new HDF5 file after 1000 frames
    if t > 999 or h5file is None:
        file_index += 1
        if h5file is not None:
            h5file.close()
        h5file_name = file_prefix + "_" + str(file_index) + ".hdf5"
        h5file = h5py.File(f"D:\\MK64AI_dataset\\{output_dir}/{h5file_name}", 'w')
        t = 0
        print("Saving HDF5 file and creating the next...")
    # Convert the image from BGRA to BGR (OpenCV format)
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    img = cv2.resize(img, (target_width, target_height))
    # Convert the prediction (decimal) to a one-hot vector
    vector = decimal_to_vector(prediction)
    # Save frame data to the HDF5 file
    h5file.create_dataset(f"frame_{t}_x", data=img)
    h5file.create_dataset(f"frame_{t}_y", data=vector)
    t += 1
    return True
Key features of the function:
- HDF5 File Management: Automatically creates new files after 1000 frames to keep datasets manageable.
- Image Processing: Converts gameplay frames to the required format (removing the alpha channel and resizing).
- Data Storage: Saves both image data (X) and steering input vectors (Y) for each frame.
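The file rotation boils down to simple arithmetic on a running frame counter. Here is a small sketch of that mapping. `locate_frame` is a hypothetical helper, and numbering the files from 1 is an assumption (the actual starting value of `file_index` is not shown above).

```python
FRAMES_PER_FILE = 1000  # a new HDF5 file is started every 1000 frames


def locate_frame(global_index, first_file_index=1):
    """Map a running frame counter to (file index, image key, action key).

    first_file_index=1 is an assumption about how the files are numbered.
    """
    file_offset, t = divmod(global_index, FRAMES_PER_FILE)
    return first_file_index + file_offset, f"frame_{t}_x", f"frame_{t}_y"


print(locate_frame(0))     # (1, 'frame_0_x', 'frame_0_y')
print(locate_frame(1000))  # (2, 'frame_0_x', 'frame_0_y')
```

This is handy when you want to find a specific moment of a recording session without opening every file.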
Generating Varied Training Data
To train the DIAMOND model effectively, the dataset needed to include diverse scenarios. I generated two types of data:
- Noised Data: Introduced randomness to the steering values produced by NeuralKart, simulating a range of driving behaviors. For example, the AI might randomly steer slightly left when it would normally go straight.
  - Purpose: Teach the model to generalize better and handle edge cases.
- Expert Data: Used only high-quality inputs from NeuralKart (without noise).
  - Purpose: Provide a baseline of “ideal” steering for the network to emulate.
The final dataset combined both types of data to balance realism and optimal behavior.
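As an illustration of the noised data idea, here is a minimal sketch of perturbing a steering value. The function name `noisy_steering`, the probability, and the noise scale are all illustrative assumptions, not the values or code I actually used.

```python
import random


def noisy_steering(value, rng, noise_prob=0.3, noise_scale=0.2):
    """Occasionally perturb a steering value, keeping it inside [-1.0, 1.0].

    noise_prob and noise_scale are illustrative values only.
    """
    if rng.random() < noise_prob:
        value += rng.uniform(-noise_scale, noise_scale)
    return max(-1.0, min(value, 1.0))


rng = random.Random(0)  # seeded so runs are repeatable
expert = [0.0, 0.1, -0.4, 1.0]
noised = [noisy_steering(v, rng) for v in expert]
print(noised)
```

Clamping after adding the noise matters: without it, a perturbed full-left or full-right value would fall outside the range the one-hot conversion expects.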
Dataset Limitations and Observations
While collecting data, I noticed some inherent limitations in the dataset:
- No U-Turns: Since the dataset lacked examples of full 180° turns, the trained NN could only make small corrections instead of completely reversing direction.
- Wall Collisions: Crashes into walls were rare in the dataset. When the NN encountered such a scenario, it struggled to simulate an impact and instead attempted to continue generating the track ahead.
These limitations highlight the importance of carefully curating datasets to cover edge cases.
In the clips above you can see some interesting details captured by the model. For instance, in clip #1, the time trial counter actually flickers when the kart passes the finish line.
In clip #2, the model reproduces how driving against a wall works, especially when entering the tunnel (this sometimes happened in the dataset!).
In the next section, I’ll explain how the dataset was used to train the DIAMOND model and the results of the training process.
Training DIAMOND for Mario Kart 64
With the dataset prepared, the next step was to train DIAMOND to interpret the data and produce meaningful gameplay. Here’s the process I followed to train DIAMOND for Mario Kart 64, adapting it from its original Counter-Strike: Global Offensive (CSGO) configuration.
Preprocessing HDF5 Files
DIAMOND’s training framework requires data preprocessing to convert the raw HDF5 files into a format suitable for training. This step involves modifying the process_csgo_tar_files.py script from the DIAMOND repository. By inspecting the provided GitHub commit, I made the following adjustments:
- Update Data Paths:
  - Modify the script to read HDF5 files from the MK64AI_dataset directory.
  - Ensure it processes image data (X) and corresponding steering vectors (Y).
- Resolution Scaling:
  - Ensure that data preprocessing maintains consistency between low-resolution (64x48) and full-resolution (320x240) inputs.
- Output Directory:
  - Update paths to save processed datasets in the correct locations.
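One way to sanity-check that the two resolutions stay consistent is to confirm that both axes scale by the same integer factor. The sketch below does that; `upsampling_factor` is a hypothetical helper, not a function from the DIAMOND repository.

```python
def upsampling_factor(full_res, low_res):
    """Return the integer scale between (width, height) pairs, or raise if inconsistent."""
    (full_w, full_h), (low_w, low_h) = full_res, low_res
    if full_w % low_w or full_h % low_h:
        raise ValueError("full resolution is not an integer multiple of low resolution")
    factor_w, factor_h = full_w // low_w, full_h // low_h
    if factor_w != factor_h:
        raise ValueError("width and height scale by different factors")
    return factor_w


print(upsampling_factor((320, 240), (64, 48)))  # 5
```

Running a check like this before preprocessing catches swapped width/height values early, which is the most common way these numbers go wrong.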
Adjusting DIAMOND’s Configurations
DIAMOND’s training configuration required several adjustments to support Mario Kart 64 instead of CSGO. Here’s how I adapted the key configuration files:
- config/agent/csgo.yaml:
  - Upsampling Factor: For 320x240 full-resolution and 64x48 low-resolution data, the upsampling factor remained 5 (320/64 = 5). If resolutions differ, calculate the factor accordingly.
  - Example:
    upsampling_factor: 5
- config/env/csgo.yaml:
  - Input Resolution: Update the size field to [320, 240] for full resolution. Ensure the order matches the input data (width x height).
  - Number of Actions: Set num_actions to 20, reflecting the length of the steering vector.
  - Data Paths: Update path_data_low_res and path_data_full_res to point to the preprocessed datasets.
- config/trainer.yaml:
  - Adjust training parameters for both the denoiser and upsampler: reduce steps_per_epoch and increase flexibility for testing different configurations.
  - Example (Denoiser):
    denoiser:
      training:
        num_autoregressive_steps: 4
        start_after_epochs: 0
        steps_first_epoch: 100
        steps_per_epoch: 100
  - Example (Upsampler):
    upsampler:
      training:
        num_autoregressive_steps: 1
        start_after_epochs: 0
        steps_first_epoch: 400
        steps_per_epoch: 400
Installing PyTorch with CUDA
DIAMOND requires PyTorch with GPU acceleration for efficient training. To ensure CUDA support, I manually installed the correct version of PyTorch:
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
The PyTorch website (pytorch.org) provides an interactive tool to generate the appropriate installation command based on your environment.
Optimizing Training Parameters
Training DIAMOND for Mario Kart 64 required careful optimization of batch sizes and gradient accumulation steps to avoid GPU memory overload while maximizing efficiency. Below is a detailed explanation of the chosen parameters and the reasoning behind them, based on my setup and experimentation:
Batch Size and Gradient Accumulation Steps
The training is divided into two key components: the Denoiser and the Upsampler. Each of these components required separate optimization due to differences in data resolution and GPU memory requirements.
- Denoiser:
  - Resolution: 64x48 (low resolution).
  - Optimal Batch Size: 28 on a 40GB GPU.
  - Gradient Accumulation Steps: 6.
    - This value is multiplied by the steps per epoch to calculate the total steps in an epoch.
  - Rationale:
    - The denoiser processes data at a smaller resolution, allowing for a relatively larger batch size.
    - If you decrease the batch size, increase the gradient accumulation steps to compensate.
- Upsampler:
  - Resolution: 320x240 (full resolution).
  - Optimal Batch Size: 4.
  - Gradient Accumulation Steps: 1.
    - The effective batch size remains 4, reflecting the upsampler’s limited GPU memory headroom.
  - Rationale:
    - The full-resolution data significantly increases memory requirements.
    - Each full-resolution frame is 5x larger per side (25x the pixels) than a low-resolution one, so a much smaller batch size keeps training stable.
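The arithmetic behind these choices can be sketched in a few lines. With gradient accumulation, the number of samples contributing to each optimizer update is the batch size times the accumulation steps. The pixel counts are only a rough proxy for memory pressure (real usage depends on the model's activations), and both function names are hypothetical.

```python
def effective_batch_size(batch_size, grad_accum_steps):
    """Samples contributing to each optimizer update under gradient accumulation."""
    return batch_size * grad_accum_steps


def pixels_per_batch(batch_size, width, height):
    """Raw pixel count of one batch; a rough proxy for input memory only."""
    return batch_size * width * height


print(effective_batch_size(28, 6))    # denoiser: 168
print(effective_batch_size(4, 1))     # upsampler: 4
print(pixels_per_batch(28, 64, 48))   # 86016
print(pixels_per_batch(4, 320, 240))  # 307200
```

Even with a batch size of 4, one upsampler batch carries several times more pixels than a denoiser batch of 28, which is consistent with the upsampler being the component that runs out of VRAM first.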
Steps Per Epoch
The training process required adjustments to the number of steps per epoch to control training duration and allow for iterative testing. These values came from my own experimentation:
- Denoiser:
  - Steps Per Epoch: 100.
  - Epochs: Configured for 6 epochs (6 x 100 = 600 steps in total).
  - Initial Steps for First Epoch: Set to match regular steps (100).
  - Reasoning:
    - The smaller step count per epoch allows for flexibility in testing configurations and monitoring performance over shorter intervals.
- Upsampler:
  - Steps Per Epoch: 400.
  - Epochs: Fewer epochs, to keep training time manageable.
GPU Memory Optimization
Managing GPU memory was crucial, especially for the upsampler, which handles larger inputs. During the first epoch, memory usage was observed to be lower than in subsequent epochs. To prevent crashes:
- Reserve Headroom:
  - Avoided completely maxing out VRAM during the first epoch.
  - Allowed for memory spikes in later epochs.
- Trial and Error:
  - Experimented with batch sizes and gradient accumulation steps until achieving stability.
Final Values for Training
- Denoiser:
  - Gradient Accumulation Steps: 6
  - Batch Size: 28
  - Steps Per Epoch: 100
  - Total Steps: 600
- Upsampler:
  - Gradient Accumulation Steps: 1
  - Batch Size: 4
  - Steps Per Epoch: 400
  - Total Steps: 400
Launching Training
To start training, I ran the modified main.py with the optimized parameters. Monitoring was critical to avoid GPU memory crashes:
- First Epoch: Typically uses less memory than subsequent epochs, so adjustments were made to avoid crashes in later epochs.
- Training Time: On a Google Colab A100 GPU, training took approximately 10 hours and cost around $10.
Training Outputs
Upon training completion (or interruption), DIAMOND saves results in the outputs directory. Key files include:
- checkpoints/state.pt: Contains all training parameters (4GB, not needed for inference).
- agent_versions/<checkpoint>.pt: Smaller checkpoint file (~1.42GB) used for inference.
Now that your model is trained, the next step is to make it play the game. This requires some reverse engineering to integrate the Neural Network (NN) into the Mario Kart 64 emulator. Let’s proceed step by step.
Understanding and Generating “Spawns”
The NN needs “spawns” to initialize the simulation. Spawns are essentially preconfigured starting points that include a few initial frames and corresponding actions. Since DIAMOND doesn’t have a built-in spawn creation tool for Mario Kart 64, I had to create my own. Here’s how the spawns are structured and how to generate them:
Structure of Spawns
In the CSGO DIAMOND model uploaded to HuggingFace (here), you can find a spawn folder. Inside it, there are numbered folders with .npy files and an info.json file. The .npy files contain the data needed for initialization:
- low_res: Tensor for low-resolution frames, shape (4, 3, 48, 64)
- full_res: Tensor for full-resolution frames, shape (4, 3, 240, 320)
- act: Tensor for initial action vectors, shape (4, 20)
- next_act: Tensor for the next 200 actions, shape (200, 20)
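Since swapped height and width are an easy mistake to make, a shape check on each spawn is worth having. The sketch below compares a dictionary of shapes against the ones listed above. Both function names are hypothetical, and the defaults assume the Mario Kart setup from this post (4 initial frames, 20 actions, 200 replayed actions).

```python
def expected_spawn_shapes(n_init=4, n_next=200, n_actions=20,
                          low_hw=(48, 64), full_hw=(240, 320)):
    """Shapes listed above, keyed by tensor name (layout: frames, channels, H, W)."""
    return {
        "low_res": (n_init, 3, *low_hw),
        "full_res": (n_init, 3, *full_hw),
        "act": (n_init, n_actions),
        "next_act": (n_next, n_actions),
    }


def check_spawn(shapes):
    """Return a list of mismatches; an empty list means the spawn looks consistent."""
    expected = expected_spawn_shapes()
    return [f"{name}: expected {want}, got {shapes.get(name)}"
            for name, want in expected.items() if shapes.get(name) != want]


print(check_spawn(expected_spawn_shapes()))  # []
```

With NumPy arrays loaded from the .npy files, you would pass something like `{name: array.shape for name, array in arrays.items()}`, since `array.shape` is a plain tuple.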
The first three tensors (low_res, full_res, and act) correspond to the first four frames of the simulation. The next_act tensor contains the subsequent 200 action vectors, enabling the NN to “replay” these actions during the simulation.
Generating Spawns
To generate spawns for Mario Kart 64, we randomly select frames and actions from the hdf5 dataset. This includes extracting 4 frames for initialization and 200 action vectors for simulation. Here’s the process:
- Randomly Select Frames: The starting frame is randomly chosen from the range [0, 796], which leaves room in a 1000-frame file for the 4 initial frames plus the next 200 actions. For each of the first four frames:
  - Extract the image and action data (X and Y) from the HDF5 file.
  - Generate the low-resolution version of the image by resizing it using bicubic interpolation.
  - Save both the full-res and low-res images.
  Code example:
    i = random.randint(0, 796)
    for j in range(4):
        frame_x = f'frame_{i}_x'
        frame_y = f'frame_{i}_y'
        if frame_x in h5file and frame_y in h5file:
            # Append the frames
            data_x = h5file[frame_x][:]
            data_y = h5file[frame_y][:]
            data_x_frames.append(data_x)
            data_y_frames.append(data_y)
            # Save full-res image
            image = Image.fromarray(data_x)
            image.save(spawn_dir / f"{existing_spawns}/full_res_{j}.png")
            resized_image = T.resize(image, (low_res_h, low_res_w), interpolation=T.InterpolationMode.BICUBIC)
            low_res_frames.append(np.array(resized_image))
        else:
            print(f"One or both of {frame_x} or {frame_y} do not exist in the file.")
        i += 1
- Extract the Next 200 Actions: After selecting the first four frames, iterate through the next 200 frames to save only the action vectors (Y data).
  Code example:
    for _ in range(200):
        next_act = f'frame_{i}_y'
        if next_act in h5file:
            next_act_data = h5file[next_act][:]
            next_act_frames.append(next_act_data)
        i += 1  # advance to the next frame
- Stack and Save the Tensors: After extracting all data, stack the tensors for full-res, low-res frames, and actions. Rearrange the dimensions of the images to match the required format: (FRAME_N, COLOR_CHANNEL, H, W).
  Code example:
    data_x_stacked = np.stack(data_x_frames)
    data_y_stacked = np.stack(data_y_frames)
    next_act_stacked = np.stack(next_act_frames)
    low_res_stacked = np.stack(low_res_frames)
    # Move the channel axis: (N, H, W, C) -> (N, C, H, W)
    low_res_stacked = np.transpose(low_res_stacked, (0, 3, 1, 2))
    data_x_stacked = np.transpose(data_x_stacked, (0, 3, 1, 2))
- Save the Spawns: The spawn.py script also saves images of the full-resolution frames so you can preview the spawns. This is helpful for selecting interesting starting points. Spawns are saved in the ./training/csgo/spawn folder, which you need to create manually.
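A quick way to check the starting range is to count how many frames one spawn consumes: 4 frames for initialization plus 200 replayed actions, all of which must fit inside a single 1000-frame file. The helpers below are hypothetical and only do that bookkeeping.

```python
def max_start_index(frames_per_file=1000, n_init=4, n_next=200):
    """Largest start frame that keeps all n_init + n_next frames inside one file."""
    needed = n_init + n_next
    if needed > frames_per_file:
        raise ValueError("file is too short for one spawn")
    return frames_per_file - needed


def spawn_frame_ranges(start, n_init=4, n_next=200):
    """Frame indices used for initialization and for the replayed actions."""
    init = range(start, start + n_init)
    replay = range(start + n_init, start + n_init + n_next)
    return init, replay


start = max_start_index()
init, replay = spawn_frame_ranges(start)
print(start, list(init), replay[-1])  # 796 [796, 797, 798, 799] 999
```

A start index beyond that bound would silently produce fewer than 200 actions, because the extraction loop skips keys that do not exist in the file.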
Tips for Creating Spawns
- Generate around 10 spawns and manually inspect them to select the most interesting ones. Delete the less useful spawns to keep your setup clean.
- Check the HuggingFace model structure (here) for guidance on where to place the csgo.pt model (your trained model, renamed accordingly) and how to organize the spawn folder.
Updating Configuration Files
Lastly, ensure that your training configurations (e.g., resolution, number of actions, etc.) match the setup in your DIAMOND environment. For example:
- Update the path_data values in the spawn configuration files to point to your dataset.
- Match the upsampling_factor and resolution settings in the appropriate YAML files.
This completes the process of creating spawns and preparing the DIAMOND NN for Mario Kart 64 gameplay.
Tweaking the Environment to Map Keyboard Inputs to Vectors
This section involves customizing the environment to map your keyboard inputs to the action vectors used by the neural network. Since this step requires reverse engineering and depends on your specific game setup, you’ll need to experiment and adapt as necessary. Below, I’ll explain the key files and processes involved.
Key Files and Their Roles
- ./src/csgo/keymap.py
  This file defines a keymap, which maps each key (as a pygame enum) to a name. This is used to display key presses on the pygame interface. For example, you might map the W key to “forward” or A to “left.”
- ./src/csgo/action_processing.py
  This file is critical for converting user inputs into the vectors your neural network understands and vice versa. It bridges the gap between raw inputs (like key presses) and action vectors.
  - N_KEYS Definition
    At the start of the file, update:
      N_KEYS = 20  # Update this to the size of your action vector
    Replace 20 with the number of elements in your action vector.
  - Decimal to Vector and Vector to Decimal Functions
    These functions map a floating-point value (e.g., steering input) to a one-hot vector and back. Here’s a closer look:
    - decimal_to_vector: Converts a decimal value (e.g., -1.0 to 1.0 for steering) into a one-hot vector of size N_KEYS.
      def decimal_to_vector(decimal_number):
          vector = [0] * N_KEYS
          decimal_number = max(-1.0, min(decimal_number, 1.0))
          index = round((decimal_number + 1.0) * 10)  # Map -1.0 to 0, 1.0 to 19
          index = min(index, N_KEYS - 1)  # Ensure index doesn't exceed bounds
          vector[index] = 1
          return vector
    - vector_to_decimal: Converts a one-hot vector back to a decimal value.
      def vector_to_decimal(vector):
          if len(vector) != N_KEYS:
              raise ValueError("Input vector must be of length N_KEYS.")
          index = vector.index(1)  # Find which element is "hot"
          return (index / 10) - 1.0
  - Encoding and Decoding Actions
    These functions translate between high-level CSGOAction objects and the action vectors used by the neural network.
    - encode_csgo_action: Converts user inputs into a vector. This function encodes keys, and if applicable, mouse movements. For a simplified setup without mouse inputs:
      def encode_csgo_action(csgo_action: CSGOAction, device: torch.device) -> torch.Tensor:
          # decimal_to_vector already returns the full one-hot list, so use it directly
          keys_pressed_onehot = np.array(decimal_to_vector(csgo_action.steering_value))
          return torch.tensor(keys_pressed_onehot, device=device, dtype=torch.float32)
    - decode_csgo_action: Decodes the neural network’s vector output into a CSGOAction object for processing:
      def decode_csgo_action(y_preds: torch.Tensor) -> CSGOAction:
          y_preds = y_preds.squeeze()
          keys_pred = y_preds[:N_KEYS].tolist()  # vector_to_decimal expects a plain list
          return CSGOAction(vector_to_decimal(keys_pred))
  - Mapping Specific Key Inputs
    In the original DIAMOND setup, mouse movements and clicks are included. For a keyboard-only game like Mario Kart, you can simplify this. Replace the mouse and click handling code with logic tailored to your game’s control scheme.
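It helps to know what a value looks like after a full encode and decode cycle, because the one-hot vector only has 20 slots. The sketch below restates the two conversion functions in compact form (same arithmetic, plain Python lists, no torch) and prints a few round trips. Note that full right steering (1.0) comes back as roughly 0.9, since index 19 decodes to 19 / 10 - 1.0.

```python
N_KEYS = 20


def decimal_to_vector(value):
    """One-hot encode a steering value in [-1.0, 1.0] (compact restatement)."""
    vector = [0] * N_KEYS
    value = max(-1.0, min(value, 1.0))
    vector[min(round((value + 1.0) * 10), N_KEYS - 1)] = 1
    return vector


def vector_to_decimal(vector):
    """Decode a one-hot vector back to the steering value its slot represents."""
    return vector.index(1) / 10 - 1.0


for value in (-1.0, -0.33, 0.0, 0.8, 1.0):
    decoded = vector_to_decimal(decimal_to_vector(value))
    print(f"{value:+.2f} -> {decoded:+.2f}")
```

In other words, the encoding quantizes steering to steps of 0.1, which is plenty for keyboard control but worth remembering if you compare decoded values against the original NeuralKart predictions.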
High-Level Action Structure
The CSGOAction Class
This class represents all inputs (keys, mouse, and clicks). For keyboard-only setups, you can strip it down to just the keys:
@dataclass
class CSGOAction:
    keys: List[int]  # Replace with keys alone
Key Mapping Example
To encode specific keys into vector positions, modify:
def encode_csgo_action(csgo_action: CSGOAction, device: torch.device) -> torch.Tensor:
    keys_pressed_onehot = np.zeros(N_KEYS)
    for key in csgo_action.keys:
        if key == "w":
            keys_pressed_onehot[0] = 1
        elif key == "a":
            keys_pressed_onehot[1] = 1
        # Map remaining keys to corresponding vector indices
    return torch.tensor(keys_pressed_onehot, device=device, dtype=torch.float32)
Suggestions and Challenges
- Adjust Input Mapping to Match Your Game
  Map each key to a specific index in your action vector. For example:
  - W → Forward movement
  - A → Turn left
  - D → Turn right
  Ensure the mapping aligns with how your neural network was trained.
- Test and Iterate
  The vector encoding must precisely match what your model expects. Errors here will result in incorrect behavior.
- Handle Forbidden Movements
  If certain combinations (e.g., moving in two opposite directions) are invalid, ensure your decoding process filters these.
This setup process demands rigorous testing to ensure the input-to-vector mappings align with your training configuration. The key takeaway is to mirror the logic from your training phase, effectively reconstructing the action encoding. With patience and experimentation, your NN will interpret the game controls correctly!
Generating and Using CSGOAction: Key Steps and Customization
This section focuses on how the game generates CSGOAction objects and integrates them into the environment. Let’s break it down into clear steps and suggestions for customization.
Understanding the Workflow
- World Model Environment (./src/envs/world_model_env.py)
  - reset() Function: This initializes the environment by calling self.generator_init.send(), which determines a new spawn point. It sets the high-resolution and low-resolution images (stored in obs and act variables) for processing. HX and CX (used for reinforcement learning) can be ignored here, as they are not relevant in this version of the code.
  - step() Function: This function calculates the next observation, upsamples it, and interacts with the neural network. It serves as the interface between the environment and the CSGOAction.
- Game Environment (./src/game/playenv.py)
  - This file interacts with world_model_env.py, calling its reset() and step() functions to process game logic.
- Key Event Handling (./src/game/game.py)
  - This is where pygame captures player inputs. In the run() function:
    - pygame.KEYDOWN adds pressed keys to the keys_pressed list.
    - pygame.KEYUP removes released keys from the keys_pressed list.
  - These key events directly influence the creation of the CSGOAction object.
    if event.type == pygame.KEYDOWN:
        keys_pressed.append(event.key)
    elif event.type == pygame.KEYUP and event.key in keys_pressed:
        keys_pressed.remove(event.key)
  The list of keys pressed (keys_pressed) is passed to the CSGOAction constructor. If your CSGOAction is modified (e.g., no mouse handling), you need to remove mouse-related parameters from its initialization.
Customizing CSGOAction
The default CSGOAction includes keys, mouse movements, and click data:
csgo_action = CSGOAction(keys_pressed, mouse_x, mouse_y, l_click, r_click)
In a keyboard-only setup, you can simplify it. For example:
csgo_action = CSGOAction(keys_pressed)
Then adjust the CSGOAction class accordingly:
@dataclass
class CSGOAction:
    keys: List[int]
Customizing Controls (Mario Kart Example)
The run() function (from ./src/game/game.py) can be adapted to interpret custom controls. For instance, in a Mario Kart-like setup, steering can be mapped to the number keys (1 through 0), each corresponding to a fixed steering value.
def get_steering_value(keys_pressed):
    if pygame.K_1 in keys_pressed:
        return -1.0
    elif pygame.K_2 in keys_pressed:
        return -0.8
    elif pygame.K_3 in keys_pressed:
        return -0.6
    elif pygame.K_4 in keys_pressed:
        return -0.4
    elif pygame.K_5 in keys_pressed:
        return -0.2
    elif pygame.K_6 in keys_pressed:
        return 0.0
    elif pygame.K_7 in keys_pressed:
        return 0.2
    elif pygame.K_8 in keys_pressed:
        return 0.4
    elif pygame.K_9 in keys_pressed:
        return 0.6
    elif pygame.K_0 in keys_pressed:
        return 0.8
    else:
        return None  # No relevant key pressed
Incorporate it into the run() function:
# steering_value keeps its previous value when no steering key is pressed,
# so initialize it (e.g., steering_value = 0.0) before the event loop.
k = get_steering_value(keys_pressed)
if k is not None:
    steering_value = k
csgo_action = CSGOAction(steering_value)
next_obs, rew, end, trunc, info = self.env.step(csgo_action)
This method ensures intuitive and responsive control without requiring a joystick or progressive input for steering.
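The same mapping can be written as a lookup table, which also makes the "keep the last steering value" behavior easy to test without a window. This is a sketch under assumptions: plain key names stand in for the pygame key constants so it runs without pygame, and `replay` is a hypothetical helper that mimics what happens across frames in the game loop.

```python
# Key names stand in for pygame key constants so the sketch runs without pygame.
STEERING_BY_KEY = {
    "1": -1.0, "2": -0.8, "3": -0.6, "4": -0.4, "5": -0.2,
    "6": 0.0, "7": 0.2, "8": 0.4, "9": 0.6, "0": 0.8,
}


def get_steering_value(keys_pressed):
    """Return the value of the first mapped key in table order, or None."""
    for key, value in STEERING_BY_KEY.items():
        if key in keys_pressed:
            return value
    return None


def replay(frames, initial=0.0):
    """Apply the 'keep the last value' rule across a sequence of key states."""
    steering, history = initial, []
    for keys_pressed in frames:
        k = get_steering_value(keys_pressed)
        if k is not None:
            steering = k
        history.append(steering)
    return history


print(replay([[], ["3"], [], ["8", "9"], []]))  # [0.0, -0.6, -0.6, 0.4, 0.4]
```

When two steering keys are held at once, the one that comes first in the table wins, which matches the priority of the if/elif chain.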
Redirecting the Model Path (./src/play.py)
By default, play.py fetches the DIAMOND snapshot from Hugging Face:
path_hf = Path(snapshot_download(repo_id="eloialonso/diamond", allow_patterns="csgo/*"))
To point it to a local directory (e.g., ./trained/), modify it:
path_hf = Path("./trained/")
Ensure the ./trained/ directory contains:
- The game folder (CSGO or your equivalent).
- The model and spawns.
Debugging and Troubleshooting
After making these changes, errors are likely. The key to resolving them is:
- Carefully read the stack trace to identify where the inconsistency lies.
- Check if modifications to CSGOAction propagate correctly through all references.
- Test individual components, such as the get_steering_value function, to ensure they behave as intended.
Final Notes
This guide equips you to adapt the DIAMOND framework to your custom game environment, such as Mario Kart. However, performance limitations may arise:
- A high-end GPU (e.g., 3090) is needed for decent FPS (around 10 FPS).
- Lower-spec systems might experience significant lag (e.g., 0.2 FPS on a 4GB VRAM GPU).
Community Support
For additional help:
- Join the DIAMOND Discord and explore the #training-support channel.
- Document your debugging process, as this can aid others with similar challenges.
Conclusion
Working on this was a huge journey with its ups and downs. If you want a glimpse of my desperation while figuring all of this out, join the DIAMOND Discord and look for my thread in #training-support. I documented my whole process there. In the next few days I’ll train the model from scratch on another game and document all the progress in a video. That way you’ll see my thought process while applying it to a different game, and you can follow along step by step to get familiar. In the meantime, I hope this in-depth technical blog post was useful. Thanks for reading all of this. Feel free to message me if you have any questions or want to have a chat (on the DIAMOND Discord or on Telegram).