1
0
Fork 0
alex-retrieval/graph/graph.py

333 lines
11 KiB
Python

import math
import re
import statistics
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.artist import setp
import util
def load_results(filename) -> dict:
    """Parse a benchmark log stored under ``results/``.

    Every usable line holds whitespace-separated fields in this order:
    num_servers, database_size, block_size, protocol_name, total_cpu_time,
    bits_sent, bits_received. Lines containing the text ``error`` and
    blank lines are skipped.

    Args:
        filename: Name of the log file inside the ``results/`` directory.

    Returns:
        A ``defaultdict(list)`` mapping
        ``(num_servers, database_size, block_size, protocol_name)`` to the
        list of parsed rows (one dict per log line) for that configuration.
    """
    # Field names and their converters, in the column order of the log file.
    fields = (
        ("num_servers", int),
        ("database_size", int),
        ("block_size", int),
        ("protocol_name", str),
        ("total_cpu_time", int),
        ("bits_sent", int),
        ("bits_received", int),
    )
    results = defaultdict(list)
    with open(f"results/{filename}", "r") as file:
        for line in file:
            if "error" in line:
                continue
            values = line.split()
            if not values:
                # A blank line carries no measurement; without this guard
                # the key lookup below fails with a KeyError.
                continue
            row = {name: convert(raw) for (name, convert), raw in zip(fields, values)}
            key = (
                row["num_servers"],
                row["database_size"],
                row["block_size"],
                row["protocol_name"],
            )
            results[key].append(row)
    return results
def clean_results(results) -> dict:
    """Average repeated measurements and regroup them by protocol name.

    Args:
        results: Mapping from a test configuration to the non-empty list of
            row dicts measured for it (e.g. the output of ``load_results``).

    Returns:
        A ``defaultdict(list)`` mapping each protocol name to one row per
        configuration. Each row is a copy of the configuration's first
        measurement with ``total_cpu_time``, ``bits_sent`` and
        ``bits_received`` replaced by their mean over all measurements.
    """
    averaged_fields = ("total_cpu_time", "bits_sent", "bits_received")
    cleaned_results = defaultdict(list)
    # Only the grouped rows matter here; the configuration key is repeated
    # inside every row.
    for runs in results.values():
        first = runs[0]
        # The mean does not depend on input order, so no sorting is needed.
        averages = {
            field: statistics.mean([int(run[field]) for run in runs])
            for field in averaged_fields
        }
        cleaned_results[first["protocol_name"]].append({**first, **averages})
    return cleaned_results
def filter_results(results: dict, func: callable):
    """Apply a row predicate to every protocol's result list.

    Args:
        results: Mapping from protocol name to a list of result rows.
        func: Predicate called with one row; rows for which it returns a
            truthy value are kept.

    Returns:
        A new dict with the same protocol names, each mapped to the list of
        rows that passed the predicate (possibly empty).
    """
    filtered = {}
    for protocol_name, rows in results.items():
        filtered[protocol_name] = [row for row in rows if func(row)]
    return filtered
def save_fig(plt, title):
    """Save the figure as ``plots/<title>.pdf`` with a sanitised file name.

    Args:
        plt: Object providing ``savefig`` (callers in this file pass
            ``matplotlib.pyplot``).
        title: Figure title; every non-word character is replaced by an
            underscore to form the file name.
    """
    non_word = re.compile(r"\W")
    file_stem = non_word.sub("_", title)
    plt.savefig(f"plots/{file_stem}.pdf")
def with_bandwidth(result: dict, bandwidth=10):
    """Estimate total time as CPU time plus simulated transfer time.

    Args:
        result: Row with ``total_cpu_time``, ``bits_sent`` and
            ``bits_received`` entries.
        bandwidth: Simulated link speed in Mbit/s.

    Returns:
        ``total_cpu_time`` plus the time needed to move all sent and
        received bits at the given bandwidth, but never less than 1.
    """
    cpu_time = result["total_cpu_time"]
    transferred_bits = result["bits_sent"] + result["bits_received"]
    bits_per_ms = bandwidth * 1000  # 1000 bits/ms = 1 Mbit/s
    estimate = cpu_time + (transferred_bits / bits_per_ms)
    return max(1, estimate)
def plot(all_results: dict, y_func: callable, x_func: callable, title=None, y_label=None, x_label=None,
         logx=False, logy=False, scatter=False):
    """Draw one series per protocol on a fresh figure and save it as a PDF.

    Args:
        all_results: Mapping from protocol name to a list of result rows.
        y_func: Maps a row to its y value.
        x_func: Maps a row to its x value; rows are drawn in ascending x.
        title: Used only to build the output file name (it is not drawn on
            the figure). Despite the ``None`` default a string is required,
            because ``save_fig`` runs a regex substitution on it.
        y_label: Optional y-axis label.
        x_label: Optional x-axis label.
        logx: Use a base-2 logarithmic x axis.
        logy: Use a base-2 logarithmic y axis.
        scatter: Draw markers only instead of connected lines.

    The figure is left open; closing it is up to the caller.
    """
    _, ax = plt.subplots()
    draw = ax.scatter if scatter else ax.plot
    for protocol_name, results in all_results.items():
        sorted_results = sorted(results, key=x_func)
        draw(
            [x_func(r) for r in sorted_results],
            [y_func(r) for r in sorted_results],
            label=protocol_name.replace("_", " "),
        )
    # ``base`` replaced the per-axis ``basex``/``basey`` keywords in
    # Matplotlib 3.3; the old spellings are rejected by current releases.
    if logx:
        ax.set_xscale("log", base=2)
    if logy:
        ax.set_yscale("log", base=2)
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    ax.legend(loc="upper left")
    save_fig(plt, title)
def plot_3x_with_simulated_bandwidth(all_results: dict, title: str):
    """Plot estimated total time at 10 and 100 Mbit/s in two side-by-side panels.

    Both panels share their x and y axes (base-2 logarithmic) and show one
    line per protocol; x is ``database_size * block_size`` and y is
    ``with_bandwidth`` at the panel's bandwidth. The figure is saved via
    ``save_fig`` and left open.

    Args:
        all_results: Mapping from protocol name to a list of result rows.
        title: Used to build the output file name.
    """
    ax1 = plt.subplot(121)
    ax2 = plt.subplot(122, sharex=ax1, sharey=ax1)
    ax1.set_ylabel("Time (ms)")
    # The y axis is shared, so the right panel does not repeat tick labels.
    setp(ax2.get_yticklabels(), visible=False)
    for ax in (ax1, ax2):
        ax.set_xlabel("Total Database Size (bits)")
        ax.tick_params("y")
        # ``base`` replaced ``basex``/``basey`` in Matplotlib 3.3; the old
        # spellings are rejected by current releases.
        ax.set_xscale("log", base=2)
        ax.set_yscale("log", base=2)
    ax1.set_title("10 Mbit/s")
    ax2.set_title("100 Mbit/s")

    def total_size(r):
        return r["database_size"] * r["block_size"]

    for protocol_name, results in all_results.items():
        sorted_results = sorted(results, key=total_size)
        sizes = [total_size(r) for r in sorted_results]
        label = protocol_name.replace("_", " ")
        for ax, bandwidth in ((ax1, 10), (ax2, 100)):
            ax.plot(
                sizes,
                [with_bandwidth(r, bandwidth) for r in sorted_results],
                label=label,
            )
    # One legend is enough: both panels draw the same protocols in the same order.
    ax1.legend(loc="upper left")
    save_fig(plt, title)
def plot_send_receive(all_results: dict, title: str):
    """Plot bits sent and bits received in two side-by-side panels.

    Both panels use base-2 logarithmic axes and share the x axis
    (``database_size * block_size``); each shows one line per protocol.
    Values are clamped to at least 1 before plotting. The figure is saved
    via ``save_fig`` and left open.

    Args:
        all_results: Mapping from protocol name to a list of result rows.
        title: Used to build the output file name.
    """
    ax1 = plt.subplot(121)
    ax2 = plt.subplot(122, sharex=ax1)
    ax1.set_ylabel("Sent (bits)")
    ax2.set_ylabel("Received (bits)")
    # NOTE(review): the y axes are not shared here, so hiding these tick
    # labels leaves the right panel without a readable scale - confirm
    # this is intended.
    setp(ax2.get_yticklabels(), visible=False)
    ax2.yaxis.set_label_position("left")
    for ax in (ax1, ax2):
        ax.set_xlabel("Total Database Size (bits)")
        ax.tick_params("y")
        # ``base`` replaced ``basex``/``basey`` in Matplotlib 3.3; the old
        # spellings are rejected by current releases.
        ax.set_xscale("log", base=2)
        ax.set_yscale("log", base=2)

    def total_size(r):
        return r["database_size"] * r["block_size"]

    for protocol_name, results in all_results.items():
        sorted_results = sorted(results, key=total_size)
        sizes = [total_size(r) for r in sorted_results]
        label = protocol_name.replace("_", " ")
        for ax, field in ((ax1, "bits_sent"), (ax2, "bits_received")):
            ax.plot(
                sizes,
                [max(1, r[field]) for r in sorted_results],
                label=label,
            )
    ax1.legend(loc="upper left")
    save_fig(plt, title)
def matrixify(results: list, x_func: callable, y_func: callable, z_func: callable):
    """Arrange result rows into a 2-D grid indexed by sorted y and x values.

    Args:
        results: Rows to arrange.
        x_func: Maps a row to its column key.
        y_func: Maps a row to its row key.
        z_func: Maps a row to the cell value.

    Returns:
        ``(matrix, x_labels, y_labels)``: a NumPy array with one row per
        distinct y key and one column per distinct x key (both ascending),
        followed by the sorted key lists. Cells without a row hold 0; when
        several rows map to the same cell the last one wins.
    """
    x_labels = sorted({x_func(r) for r in results})
    y_labels = sorted({y_func(r) for r in results})
    cells = {}
    for r in results:
        cells[(y_func(r), x_func(r))] = z_func(r)
    grid = [[cells.get((y, x), 0) for x in x_labels] for y in y_labels]
    return np.array(grid), x_labels, y_labels
def plot_scheme_heatmap(results: list, title: str, bandwidth: int):
    """Draw and save a heatmap of estimated total time for one scheme.

    The grid has one cell per (block size, database size) pair, valued by
    ``with_bandwidth`` at the given bandwidth. Axis tick labels are written
    as powers of two, using the integer part of each size's log2.

    Args:
        results: Result rows of a single protocol.
        title: Used to build the output file name.
        bandwidth: Simulated bandwidth passed to ``with_bandwidth``.
    """
    def power_of_two_label(size):
        return f"$2^{{{int(math.log2(size))}}}$"

    grid, database_sizes, block_sizes = matrixify(
        results,
        x_func=lambda r: r["database_size"],
        y_func=lambda r: r["block_size"],
        z_func=lambda r: with_bandwidth(r, bandwidth),
    )
    block_size_labels = [power_of_two_label(size) for size in block_sizes]
    database_size_labels = [power_of_two_label(size) for size in database_sizes]
    # NOTE(review): util.heatmap is a project helper; presumably it takes
    # the data, then row labels, then column labels - confirm against util.
    _im, _cbar = util.heatmap(
        grid,
        block_size_labels,
        database_size_labels,
        xlabel="Database Size (bits)",
        ylabel="Block Size (bits)",
        cbarlabel="Time (ms)",
        logcolor=True,
        origin="lower",
    )
    save_fig(plt, title)
def plot_old_vs_new_heatmap(all_results: dict, old_func: callable, new_func: callable, title: str):
    """Draw and save a heatmap of the time difference between two schemes.

    Each cell holds ``new - old`` of the estimated total time at 10 Mbit/s
    for one (block size, database size) pair.

    Args:
        all_results: Mapping from protocol name to a list of result rows.
        old_func: Selects the baseline scheme's rows from ``all_results``.
        new_func: Selects the compared scheme's rows from ``all_results``.
        title: Used to build the output file name.
    """
    def to_matrix(rows):
        return matrixify(
            rows,
            x_func=lambda r: r["database_size"],
            y_func=lambda r: r["block_size"],
            z_func=lambda r: with_bandwidth(r, 10),
        )

    def power_of_two_label(size):
        return f"$2^{{{int(math.log2(size))}}}$"

    data_old, x_labels, y_labels = to_matrix(old_func(all_results))
    # The axis labels come from the new scheme's grid; the old ones are discarded.
    data_new, x_labels, y_labels = to_matrix(new_func(all_results))

    # NOTE(review): cells are matched by position, not by label, so the
    # comparison only lines up when both grids start from the same sizes -
    # confirm the two result sets cover the same configurations.
    def difference(row, col):
        try:
            new_value = data_new[row, col]
            old_value = data_old[row, col]
        except IndexError:
            # The old grid has no cell at this position.
            return 0
        if (new_value, old_value) == (0, 0):
            # Neither scheme has a measurement here.
            return 0
        delta = new_value - old_value
        # An exact tie is reported as 1 so it differs from the "no data" 0.
        return 1 if delta == 0 else delta

    differences = np.array(
        [[difference(i, j) for j in range(len(row))] for i, row in enumerate(data_new)]
    )
    _im, _cbar = util.heatmap(
        differences,
        [power_of_two_label(size) for size in y_labels],
        [power_of_two_label(size) for size in x_labels],
        xlabel="Database Size (bits)",
        ylabel="Block Size (bits)",
        cbarlabel="Time Difference (ms)",
        sym_logcolor=True,
        origin="lower",
    )
    save_fig(plt, title)
def main():
    """Render every figure of the evaluation into ``plots/``.

    Reads ``results_combined.log`` for the 1-bit-block line plots and
    ``results_fast_var-bs_var-db.log`` for the heatmaps. Each log is parsed
    and averaged once and then reused, since none of the plotting helpers
    modify the rows they are given.
    """
    combined = clean_results(load_results("results_combined.log"))
    one_bit_blocks = filter_results(combined, lambda r: r["block_size"] == 1)

    # CPU time only.
    plot(
        one_bit_blocks,
        y_label="Time (ms)",
        x_label="Total Database Size (bits)",
        title="Computation Time - 1-bit Block Size",
        y_func=lambda r: max(1, r["total_cpu_time"]),
        x_func=lambda r: r["database_size"] * r["block_size"],
        logx=True,
        logy=True,
    )
    plt.close()

    # CPU time plus simulated transfer time, i.e. estimated total real time.
    plot_3x_with_simulated_bandwidth(
        one_bit_blocks,
        title="Total Time with Simulated Bandwidth - 1-bit Block Size",
    )
    plt.close()

    # Network traffic.
    plot_send_receive(
        one_bit_blocks,
        title="Network Traffic - 1-bit Block Size",
    )
    plt.close()

    # Loaded only now so the plots above are still produced if this log is
    # missing.
    grid = clean_results(load_results("results_fast_var-bs_var-db.log"))

    # Per-scheme heatmaps of estimated total time over database size and
    # block size.
    for scheme in ("Send_All", "XOR", "Balanced_XOR"):
        scheme_name = scheme.replace("_", " ")
        plot_scheme_heatmap(
            grid[scheme],
            title=f"Total Simulated Time Heatmap: {scheme_name} - Varying Database Size and Block Size - 10Mbit/s",
            bandwidth=10,
        )
        plt.close()

    # Heatmaps comparing each older scheme against Balanced XOR.
    for old_scheme in ("Send_All", "XOR"):
        old_name = old_scheme.replace("_", " ")
        plot_old_vs_new_heatmap(
            grid,
            # Bind the scheme as a default so the lambda does not depend on
            # the loop variable's later values.
            old_func=lambda rs, scheme=old_scheme: rs[scheme],
            new_func=lambda rs: rs["Balanced_XOR"],
            title=f"Total Simulated Time Heatmap: {old_name} vs Balanced XOR - Varying Database Size and Block Size - 10 Mbit/s",
        )
        plt.close()
# Generate the plots only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()