"""Module for plotting Counter-Strike data."""
import io
import math
import warnings
from dataclasses import dataclass
from typing import Any, Literal, Optional
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import tqdm
from matplotlib.axes import Axes
from matplotlib.colors import LogNorm
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from PIL import Image
import awpy.data
import awpy.plot.utils
@dataclass
class PointSettings:
    """Per-point styling options used when drawing positions on a map.

    Every field has a default, so ``PointSettings()`` is a valid instance.

    Attributes:
        marker (str): Marker style for plotting. Defaults to 'o'.
        color (str): Color of the marker. Defaults to 'red'.
        size (int): Size of the marker. Defaults to 10.
        hp (Optional[int]): Health points (0-100). Defaults to None.
        armor (Optional[int]): Armor points (0-100). Defaults to None.
        direction (Optional[Tuple[float, float]]): (pitch, yaw) angles.
            Defaults to None.
        label (Optional[str]): Text label for the point. Defaults to None.
        alpha (float): Transparency level (0.0 - 1.0). Defaults to 1.0.
    """

    marker: str = "o"
    color: str = "red"
    size: int = 10
    hp: Optional[int] = None
    armor: Optional[int] = None
    direction: Optional[tuple[float, float]] = None
    label: Optional[str] = None
    alpha: float = 1.0

    @classmethod
    def from_dict(cls: type["PointSettings"], settings: dict[str, Any]) -> "PointSettings":
        """Build a PointSettings from a plain dictionary of field values.

        Only the key names are checked; the values are handed to the
        constructor unchanged, so no type validation happens here.

        Args:
            settings (dict): Mapping of field name to value. Missing fields
                keep their defaults.

        Raises:
            ValueError: If `settings` contains a key that is not a field of
                this dataclass.

        Returns:
            PointSettings: Instance populated from `settings`.
        """
        known_fields = {field.name for field in cls.__dataclass_fields__.values()}
        unknown_keys = set(settings).difference(known_fields)
        if not unknown_keys:
            return cls(**settings)

        err_msg = f"Unknown keys provided: {unknown_keys}"
        raise ValueError(err_msg)
@dataclass
class PlotPositionMetadata:
    """Data structure holding transformed plotting data for a single point.

    Attributes:
        x_pos (float): Transformed x coordinate of the point.
        y_pos (float): Transformed y coordinate of the point.
        plot_settings (PointSettings): PointSettings used to draw the point.
    """

    # One scalar per axis, not a pair: the producer stores a single transformed
    # value here and the drawing code does arithmetic such as
    # `x_pos - bar_length / 2` on it.
    x_pos: float
    y_pos: float
    plot_settings: PointSettings
[docs]
def plot(
    map_name: str,
    points: list[tuple[float, float, float]] | None = None,
    lower_points_frac: float | None = 0.4,
    point_settings: list[PointSettings] | list[dict[str, Any]] | None = None,
) -> tuple[Figure, Axes]:
    """Plot a Counter-Strike map with optional points.

    Args:
        map_name (str): Name of the map to plot. E.g. "de_dust2"
            ("dust2" or "de_dust2.png" will not work). A name ending in
            "_lower" selects the lower-level image of that map.
        points (list[tuple[float, float, float]], optional):
            list of points to plot. Each point is (X, Y, Z). Defaults to None.
        lower_points_frac (optional, float): The factor by which to multiply
            the opacity of a given point if it is on the lower level of the
            map and `map_name` is NOT referencing the lower level (i.e.
            `map_name` does not end in "_lower"). Defaults to 0.4; None is
            treated as that default.
            If `map_name` is referencing the lower level of a map (i.e.
            ends in "_lower") then this argument is ignored: lower points
            keep their opacity and upper points are not drawn.
        point_settings (list[PointSettings], list[dict[str, Any]], optional):
            list of PointSettings objects or dictionaries with settings for
            each point. Each dictionary may contain:
            - 'marker': str (default 'o')
            - 'color': str (default 'red')
            - 'size': float (default 10)
            - 'hp': int (0-100)
            - 'armor': int (0-100)
            - 'direction': tuple[float, float] (pitch, yaw in degrees)
            - 'label': str (optional)
            - 'alpha': float (default 1.0)

    Raises:
        FileNotFoundError: Raises a FileNotFoundError if the map image is not
            found.
        ValueError: Raises a ValueError if the number of points and
            point_settings don't match.

    Returns:
        tuple[Figure, Axes]: Matplotlib Figure and Axes objects.
    """
    # The image lookup uses the full name so "<map>_lower" loads "<map>_lower.png".
    map_img_path = awpy.data.MAPS_DIR / f"{map_name}.png"
    if not map_img_path.exists():
        map_img_path_err = f"Map image not found: {map_img_path}. Might need to call `awpy get maps`"
        raise FileNotFoundError(map_img_path_err)

    map_bg = mpimg.imread(map_img_path)
    figure, axes = plt.subplots(figsize=(1024 / 300, 1024 / 300), dpi=300)
    axes.imshow(map_bg, zorder=0)
    axes.axis("off")

    # Plot points if provided
    if points is not None:
        try:
            # Pass the name WITH any "_lower" suffix: the metadata step needs it to
            # know which level is being drawn and strips it itself for lookups.
            _plot_positions(map_name, axes, points, lower_points_frac, point_settings)
        except Exception:
            # Do not leave a half-built figure registered with pyplot.
            plt.close(figure)
            raise

    figure.patch.set_facecolor("black")
    figure.tight_layout()
    return figure, axes


def _plot_positions(
    map_name: str,
    axes: Axes,
    points: list[tuple[float, float, float]] | None = None,
    lower_points_frac: float | None = 0.4,
    point_settings: list[PointSettings] | list[dict[str, Any]] | None = None,
) -> None:
    """Plots points on a map, optionally customizing plotting settings.

    Normalizes `point_settings` into PointSettings objects, converts the points
    into plotting metadata and draws them on `axes`. The caller's
    `point_settings` list is never modified.

    Args:
        map_name (str): Name of the map to plot. E.g. "de_dust2"
            ("dust2" or "de_dust2.png" will not work). May end in "_lower".
        axes (matplotlib.axes.Axes): The matplotlib axes object to plot the points onto.
        points (list[tuple[float, float, float]], optional):
            list of points to plot. Each point is (X, Y, Z). Defaults to None,
            in which case nothing is drawn.
        lower_points_frac (optional, float): Opacity factor for lower-level
            points when the upper level is drawn. Defaults to 0.4.
        point_settings (list[PointSettings], list[dict[str, Any]], optional):
            list of PointSettings or dictionaries with settings for each point.
            Defaults to None, meaning default settings for every point.

    Raises:
        ValueError: If the number of points does not match the number of
            `point_settings` entries, or a settings dict has unknown keys.

    Returns:
        None
    """
    if points is None:
        return

    if point_settings is None:
        resolved_settings = [PointSettings() for _ in points]
    elif len(points) != len(point_settings):
        settings_mismatch_err = "Number of points and point_settings do not match."
        raise ValueError(settings_mismatch_err)
    else:
        # Build a new list instead of converting dicts in place, so the caller's
        # list (e.g. reused animation frame data) is left untouched.
        resolved_settings = [
            PointSettings.from_dict(entry) if isinstance(entry, dict) else entry for entry in point_settings
        ]

    plot_metadata = _generate_plot_metadata(map_name, points, resolved_settings, lower_points_frac)
    _plot_positions_from_metadata(plot_metadata, axes)


def _generate_plot_metadata(
    map_name: str,
    points: list[tuple[float, float, float]],
    point_settings: list[PointSettings],
    lower_points_frac: float | None = 0.4,
) -> list[PlotPositionMetadata]:
    """Processes points and their settings to prepare plotting metadata.

    Points that fall outside the 1024x1024 map image, or that are on a level
    which should not be shown, are dropped from the result.

    Args:
        map_name (str): Name of the map. A trailing "_lower" marks the lower
            level as the one being drawn; it is stripped before the name is
            used for coordinate and level lookups.
        points (list[tuple[float, float, float]]): List of (x, y, z) points.
        point_settings (list[PointSettings]): List of settings for each point.
        lower_points_frac (float, optional): Opacity scaling for lower-level
            points drawn on the upper level. Defaults to 0.4; None is treated
            as that default and 0 hides those points entirely.

    Returns:
        list[PlotPositionMetadata]: One entry per drawable point, holding the
            transformed coordinates and a PointSettings copy with the final alpha.
    """
    map_is_lower = map_name.endswith("_lower")
    base_map_name = map_name.removesuffix("_lower")
    if lower_points_frac is None:
        lower_points_frac = 0.4

    plot_position_metadata_list = []
    for (x, y, z), settings in zip(points, point_settings, strict=False):
        transformed_x = awpy.plot.utils.game_to_pixel_axis(base_map_name, x, "x")
        transformed_y = awpy.plot.utils.game_to_pixel_axis(base_map_name, y, "y")

        # Skip points outside map bounds
        if transformed_x < 0 or transformed_x > 1024 or transformed_y < 0 or transformed_y > 1024:
            continue

        # Start from the caller's alpha; points with 0 HP are strongly faded.
        alpha = settings.alpha * (0.15 if settings.hp == 0 else 1.0)

        point_is_lower = awpy.plot.utils.is_position_on_lower_level(base_map_name, (x, y, z))
        if not map_is_lower and point_is_lower:
            if lower_points_frac == 0:
                continue
            alpha *= lower_points_frac
        elif map_is_lower and not point_is_lower:
            # Upper-level points are hidden when the lower level is drawn.
            continue

        # Copy the settings so the caller's object keeps its original alpha.
        updated_point_settings = PointSettings(
            marker=settings.marker,
            color=settings.color,
            size=settings.size,
            hp=settings.hp,
            armor=settings.armor,
            direction=settings.direction,
            label=settings.label,
            alpha=alpha,
        )
        plot_position_metadata_list.append(
            PlotPositionMetadata(x_pos=transformed_x, y_pos=transformed_y, plot_settings=updated_point_settings)
        )

    return plot_position_metadata_list
def _plot_positions_from_metadata(player_pos_settings: list[PlotPositionMetadata], axes: Axes) -> None:
    """Plots player positions and associated metadata on a given matplotlib axes.

    For every entry this draws an outlined marker and, depending on the
    settings, an HP bar, an armor bar, a direction line and a text label.

    - HP bar: drawn when `hp` is greater than 0.
    - Armor bar: drawn below the same condition, and only when `armor` is set.
    - Direction line: drawn when `direction` is set and the point is not known
      to be dead (`hp` is None or greater than 0).
    - Label: drawn whenever `label` is non-empty.

    Args:
        player_pos_settings (list[PlotPositionMetadata]):
            Contains the transformed (x, y) positions and plotting settings for each point.
        axes (matplotlib.axes.Axes):
            The matplotlib axes object where the positions and related visual elements
            will be plotted.

    Returns:
        None
    """
    for metadata in player_pos_settings:
        transformed_x = metadata.x_pos
        transformed_y = metadata.y_pos
        settings = metadata.plot_settings

        size = settings.size
        hp = settings.hp
        armor = settings.armor
        alpha = settings.alpha

        # Marker: a black one underneath and a slightly smaller colored one on
        # top, which reads as a black outline.
        axes.plot(
            transformed_x,
            transformed_y,
            marker=settings.marker,
            color="black",
            markersize=size,
            alpha=alpha,
            zorder=10,
        )
        axes.plot(
            transformed_x,
            transformed_y,
            marker=settings.marker,
            color=settings.color,
            markersize=size * 0.9,
            alpha=alpha,
            zorder=11,
        )

        # Bar sizes and offsets scale with the marker size
        bar_width = size * 2
        bar_length = size * 6
        vertical_offset = size * 3.5

        if hp is not None and hp > 0:
            bar_left = transformed_x - bar_length / 2
            _add_status_bar(
                axes,
                (bar_left, transformed_y + vertical_offset),
                bar_length,
                bar_length * hp / 100,
                bar_width,
                ("red", "green"),
                alpha,
            )
            # Armor is optional: skip the bar rather than failing on None.
            if armor is not None:
                _add_status_bar(
                    axes,
                    (bar_left, transformed_y + vertical_offset + bar_width),
                    bar_length,
                    bar_length * armor / 100,
                    bar_width,
                    ("lightgrey", "grey"),
                    alpha,
                )

        # Direction only makes sense for points not known to be dead; an unset
        # hp must not prevent (or crash) the drawing.
        if settings.direction and (hp is None or hp > 0):
            pitch, yaw = settings.direction
            dx = math.cos(math.radians(yaw)) * math.cos(math.radians(pitch))
            dy = math.sin(math.radians(yaw)) * math.cos(math.radians(pitch))
            line_length = size * 2
            axes.plot(
                [transformed_x, transformed_x + dx * line_length],
                [transformed_y, transformed_y + dy * line_length],
                color="black",
                alpha=alpha,
                linewidth=1,
                zorder=12,
            )

        if settings.label:
            label_offset = vertical_offset + 1.25 * bar_width
            axes.annotate(
                settings.label,
                (transformed_x, transformed_y - label_offset),
                xytext=(0, 0),
                textcoords="offset points",
                color="white",
                fontsize=6,
                alpha=alpha,
                zorder=13,
                ha="center",  # Center the text horizontally
                va="top",
            )


def _add_status_bar(
    axes: Axes,
    anchor: tuple[float, float],
    full_length: float,
    filled_length: float,
    width: float,
    colors: tuple[str, str],
    alpha: float,
) -> None:
    """Add a two-layer bar to `axes`: a full-length background, then the filled part.

    Args:
        axes (matplotlib.axes.Axes): Axes to add the rectangles to.
        anchor (tuple[float, float]): (x, y) anchor passed to both rectangles.
        full_length (float): Length of the background rectangle.
        filled_length (float): Length of the foreground rectangle.
        width (float): Thickness of both rectangles.
        colors (tuple[str, str]): (background color, foreground color).
        alpha (float): Transparency applied to both rectangles.
    """
    background_color, foreground_color = colors
    for length, facecolor in ((full_length, background_color), (filled_length, foreground_color)):
        axes.add_patch(
            Rectangle(
                anchor,
                length,
                width,
                facecolor=facecolor,
                edgecolor="black",
                alpha=alpha,
                zorder=11,
            )
        )
def _generate_frame_plot(
    map_name: str,
    frames_data: list[dict],
    lower_points_frac: float | None = 0.4,
) -> list[Image.Image]:
    """Generate frames for the animation.

    Each frame is rendered with `plot`, saved as a PNG into an in-memory buffer
    and opened as a PIL image. The matplotlib figure is closed afterwards, even
    if saving fails.

    Args:
        map_name (str): Name of the map to plot. E.g. "de_dust2"
            ("dust2" or "de_dust2.png" will not work).
        frames_data (list[dict]): list of dictionaries, each containing
            'points' and, optionally, 'point_settings' for a frame.
        lower_points_frac (optional, float): Opacity factor for lower-level
            points, forwarded unchanged to `plot`. Defaults to 0.4.

    Returns:
        list[Image.Image]: list of PIL Image objects representing each frame.
    """
    frames = []
    for frame_data in tqdm.tqdm(frames_data):
        fig, _ax = plot(
            map_name,
            frame_data["points"],
            lower_points_frac,
            # `plot` accepts None for settings, so the key may be omitted.
            frame_data.get("point_settings"),
        )

        # Convert the matplotlib figure to a PIL Image
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png", facecolor="black")
        finally:
            # Always release the figure, otherwise a failed save leaks it.
            plt.close(fig)

        buf.seek(0)
        frames.append(Image.open(buf))

    return frames
[docs]
def gif(
map_name: str,
frames_data: list[dict],
output_filename: str,
duration: int = 500,
lower_points_frac: float | None = 0.4,
) -> None:
"""Create an animated gif from a list of frames.
Args:
map_name (str): Name of the map to plot. E.g. "de_dust2"
("dust2" or "de_dust2.png" will not work).
frames_data (list[dict]): list of dictionaries, each containing
'points' and 'point_settings' for a frame.
frames (list[Image.Image]): list of PIL Image objects.
output_filename (str): Name of the output GIF file.
duration (int): Duration of each frame in milliseconds.
lower_points_frac (optional, float): The factor by which to multiply
the opacity of a given point if it is on the lower level of the
map and `map_name` is NOT referencing the lower level (i.e.
`map_name` does not end in "_lower"). Defaults to 0.4.
If `map_name` is referencing the lower level of a map (i.e.
ends in "_lower") then this argument is ignored and lower points'
alpha is set to 1 and upper points' alpha is set to 0.
"""
frames = _generate_frame_plot(
map_name,
frames_data,
lower_points_frac,
)
frames[0].save(
output_filename,
save_all=True,
append_images=frames[1:],
duration=duration,
loop=0,
)
def _hex_plot(
    ax: Axes,
    x: list[float],
    y: list[float],
    size: int,
    cmap: str,
    alpha: float,
    alpha_range: list[float] | None,
    min_alpha: float,
    max_alpha: float,
) -> Axes:
    """Draw a hexagonal-bin heatmap of (x, y) on `ax` and return `ax`.

    When `alpha_range` is not None, each bin's opacity is scaled linearly with
    its count between `min_alpha` and `max_alpha`; otherwise the single `alpha`
    value applies. Bins with a count of 0 are set to NaN.
    """
    hexbins = ax.hexbin(x, y, gridsize=size, cmap=cmap, alpha=alpha)
    bin_counts = hexbins.get_array()

    if alpha_range is not None:
        # Per-bin alpha: relative count mapped onto [min_alpha, max_alpha].
        # Computed before the empty bins are replaced by NaN below.
        relative_counts = bin_counts / bin_counts.max()
        hexbins.set_alpha(relative_counts * (max_alpha - min_alpha) + min_alpha)

    # Empty bins become NaN so they are rendered transparent.
    empty_bins = bin_counts == 0
    bin_counts[empty_bins] = np.nan
    hexbins.set_array(bin_counts)

    return ax
def _hist_plot(
    ax: Axes,
    x: list[float],
    y: list[float],
    size: int,
    cmap: str,
    alpha: float,
    alpha_range: list[float] | None,
    min_alpha: float,
    max_alpha: float,
) -> Axes:
    """Draw a 2D-histogram heatmap of (x, y) on `ax` and return `ax`.

    Args:
        ax (Axes): Axes to draw on.
        x (list[float]): X coordinates of the points.
        y (list[float]): Y coordinates of the points.
        size (int): Number of bins per axis.
        cmap (str): Name of the colormap.
        alpha (float): Uniform opacity, used when `alpha_range` is None.
        alpha_range (list[float] | None): When not None, opacity varies per
            bin between `min_alpha` and `max_alpha` according to its count.
        min_alpha (float): Opacity of the least populated non-empty bins.
        max_alpha (float): Opacity of the most populated bin.

    Returns:
        Axes: The same `ax`. Nothing is drawn if every bin is empty.
    """
    hist, xedges, yedges = np.histogram2d(x, y, bins=size)
    grid_x, grid_y = np.meshgrid(xedges[:-1], yedges[:-1])

    # histogram2d indexes as [x, y]; transpose so rows follow y like the mesh.
    counts = hist.T

    # Set counts of 0 to NaN to make them transparent
    counts[counts == 0] = np.nan

    # Nothing to show (e.g. no points): leave the axes untouched.
    if not np.isfinite(counts).any():
        return ax

    if alpha_range is None:
        ax.pcolormesh(grid_x, grid_y, counts, cmap=cmap, norm=LogNorm(), alpha=alpha)
        return ax

    # nanmax, not max: the array now contains NaN and a plain max would be NaN,
    # turning every alpha into 0.
    counts_norm = counts / np.nanmax(counts)

    # One alpha per bin (same shape as the data); empty bins are fully transparent.
    alphas = np.where(
        np.isnan(counts_norm),
        0,
        counts_norm * (max_alpha - min_alpha) + min_alpha,
    )
    ax.pcolormesh(grid_x, grid_y, counts, cmap=cmap, norm=LogNorm(), alpha=alphas)
    return ax
def _kde_plot(
    ax: Axes,
    x: list[float],
    y: list[float],
    size: int,
    cmap: str,
    alpha: float,
    alpha_range: list[float] | None,
    min_alpha: float,
    max_alpha: float,
    kde_lower_bound: float = 0.1,
) -> Axes:
    """Draw a Gaussian kernel-density heatmap of (x, y) on `ax` and return `ax`.

    Args:
        ax (Axes): Axes to draw on.
        x (list[float]): X coordinates of the points.
        y (list[float]): Y coordinates of the points.
        size (int): Number of grid steps per axis the density is evaluated on.
        cmap (str): Name of the colormap.
        alpha (float): Uniform opacity, used when `alpha_range` is None.
        alpha_range (list[float] | None): When not None, opacity varies per
            grid cell between `min_alpha` and `max_alpha` with the density.
        min_alpha (float): Lower end of the opacity range.
        max_alpha (float): Opacity of the densest cell.
        kde_lower_bound (float, optional): Cells whose density is below this
            fraction of the peak density are hidden. Defaults to 0.1.

    Returns:
        Axes: The same `ax`.
    """
    # Calculate the kernel density estimate
    kde = scipy.stats.gaussian_kde(np.vstack([x, y]))

    # Create a grid spanning the data and evaluate the KDE on it
    xmin, xmax = min(x), max(x)
    ymin, ymax = min(y), max(y)
    xi, yi = np.mgrid[xmin : xmax : size * 1j, ymin : ymax : size * 1j]
    zi = kde(np.vstack([xi.flatten(), yi.flatten()])).reshape(xi.shape)

    # Take the peak BEFORE masking: once NaN is present, zi.max() would be NaN.
    peak_density = zi.max()

    # Set very low density values to NaN to make them transparent
    zi[zi < peak_density * kde_lower_bound] = np.nan

    if alpha_range is None:
        ax.pcolormesh(xi, yi, zi, cmap=cmap, alpha=alpha)
        return ax

    zi_norm = zi / peak_density

    # One alpha per grid cell (same shape as the data); masked cells get 0.
    alphas = np.where(
        np.isnan(zi_norm),
        0,
        zi_norm * (max_alpha - min_alpha) + min_alpha,
    )
    ax.pcolormesh(xi, yi, zi, cmap=cmap, alpha=alphas)
    return ax
def verify_alpha_range(alpha_range: list[float]) -> list:
    """Verify that `alpha_range` is valid.

    Args:
        alpha_range (list[float]): Two values, (min alpha, max alpha).

    Raises:
        ValueError: If `alpha_range` does not have exactly 2 elements, if
            either value is outside [0, 1], or if the first value is greater
            than the second.

    Returns:
        list: `[min alpha, max alpha]`.
    """
    if len(alpha_range) != 2:
        msg = "alpha_range must have exactly 2 elements."
        raise ValueError(msg)

    min_val, max_val = alpha_range

    if not (0 <= min_val <= 1 and 0 <= max_val <= 1):
        msg = "alpha_range must have both values as floats between 0 and 1."
        raise ValueError(msg)

    if min_val > max_val:
        msg = "alpha_range[0] (min alpha) cannot be greater than alpha_range[1] (max alpha)."
        raise ValueError(msg)

    return [min_val, max_val]
[docs]
def heatmap(
    map_name: str,
    points: list[tuple[float, float, float]],
    method: Literal["hex", "hist", "kde"],
    size: int = 10,
    cmap: str = "RdYlGn",
    alpha: float = 0.5,
    *,
    alpha_range: list[float] | None = None,
    kde_lower_bound: float = 0.1,
) -> tuple[Figure, Axes]:
    """Create a heatmap of points on a Counter-Strike map.

    Args:
        map_name (str): Name of the map to plot. E.g. "de_dust2"
            ("dust2" or "de_dust2.png" will not work). A name ending in
            "_lower" draws the lower level of that map.
        points (list[tuple[float, float, float]]): list of points to plot.
            Points on the other level than the one being drawn, and points
            outside the 1024x1024 map image, are ignored.
        method (Literal["hex", "hist", "kde"]): Method to use for the heatmap.
        size (int, optional): Size of the heatmap grid. Defaults to 10.
        cmap (str, optional): Colormap to use. Defaults to 'RdYlGn'.
        alpha (float, optional): Transparency of the heatmap. Defaults to 0.5.
        alpha_range (list[float, float], optional): When value is provided
            here, points' transparency will vary based on the density, with
            min transparency of `alpha_range[0]` and max of `alpha_range[1]`.
            Defaults to `None`, meaning no variance of transparency.
        kde_lower_bound (float, optional): Lower bound for KDE density
            values. Defaults to 0.1.

    Raises:
        ValueError: Raises a ValueError if an invalid method or an invalid
            `alpha_range` is provided.
        FileNotFoundError: Raises a FileNotFoundError if the map image is not
            found.

    Returns:
        tuple[Figure, Axes]: Matplotlib Figure and Axes objects
    """
    # Validate arguments before any figure exists, so bad input cannot leave
    # an orphaned figure behind.
    if method not in ("hex", "hist", "kde"):
        invalid_method_err = f"Invalid method: {method!r}. Must be one of 'hex', 'hist' or 'kde'."
        raise ValueError(invalid_method_err)

    # Check and/or set alpha_range
    min_alpha, max_alpha = 0, 1
    if alpha_range is not None:
        min_alpha, max_alpha = verify_alpha_range(alpha_range)

    # The image file keeps the "_lower" suffix; the lookups below use the base name.
    image = f"{map_name}.png"
    map_is_lower = map_name.endswith("_lower")
    map_name = map_name.removesuffix("_lower")

    map_img_path = awpy.data.MAPS_DIR / image
    if not map_img_path.exists():
        map_img_path_err = f"Map image not found: {map_img_path}. Might need to call `awpy get maps`"
        raise FileNotFoundError(map_img_path_err)
    map_bg = mpimg.imread(map_img_path)

    # Keep only points on the level being drawn; warn once if any were dropped.
    same_level_points = []
    warning = ""
    for point in points:
        point_is_lower = awpy.plot.utils.is_position_on_lower_level(map_name, point)
        if point_is_lower == map_is_lower:
            same_level_points.append(point)
        else:
            warning = f"You are drawing on the {'lower' if map_is_lower else 'upper'} level of the map, but provided some points on the {'lower' if point_is_lower else 'upper'} level, which were ignored."  # noqa: E501
    if warning:
        warnings.warn(warning, UserWarning, stacklevel=2)

    x, y = [], []
    for point in same_level_points:
        x_point = awpy.plot.utils.game_to_pixel_axis(map_name, point[0], "x")
        y_point = awpy.plot.utils.game_to_pixel_axis(map_name, point[1], "y")
        # Check if the point is within bounds of the map image
        if x_point < 0 or x_point > 1024 or y_point < 0 or y_point > 1024:
            continue
        x.append(x_point)
        y.append(y_point)

    # Load and display the map
    fig, ax = plt.subplots(figsize=(1024 / 300, 1024 / 300), dpi=300)
    try:
        ax.imshow(map_bg, zorder=0, alpha=0.5)

        if method == "hex":
            ax = _hex_plot(ax, x, y, size, cmap, alpha, alpha_range, min_alpha, max_alpha)
        elif method == "hist":
            ax = _hist_plot(ax, x, y, size, cmap, alpha, alpha_range, min_alpha, max_alpha)
        else:  # "kde"; other values were rejected above
            ax = _kde_plot(
                ax,
                x,
                y,
                size,
                cmap,
                alpha,
                alpha_range,
                min_alpha,
                max_alpha,
                kde_lower_bound,
            )
    except Exception:
        # Do not leave a half-built figure registered with pyplot.
        plt.close(fig)
        raise

    ax.axis("off")
    fig.patch.set_facecolor("black")
    fig.tight_layout()

    return fig, ax