# alex-retrieval/graph/graph.py
#
# Reads benchmark logs from results/ and writes comparison plots to plots/ as PDF files.
import math
import os
import re
import statistics
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.artist import setp

import util
def load_results(filename, directory="results") -> dict:
    """Parse a whitespace-separated benchmark log into records grouped by test case.

    Every line containing the substring "error" and every blank line is skipped.
    The remaining lines are split on whitespace, and the fields are mapped in order
    onto the columns listed in ``keys`` below, each converted with its given type.
    Extra trailing fields are ignored by ``zip``. A line with too few fields raises
    ``KeyError`` if one of the four grouping columns is missing.

    Args:
        filename: name of the log file inside *directory*.
        directory: folder holding the log; defaults to "results", the location
            that was previously hard-coded.

    Returns:
        A ``defaultdict(list)`` keyed by
        ``(num_servers, database_size, block_size, protocol_name)``. Each value is
        the list of parsed records for that test case, in file order.
    """
    # Column order and converters of one log line; hoisted out of the read loop.
    keys = (
        ("num_servers", int),
        ("database_size", int),
        ("block_size", int),
        ("protocol_name", str),
        ("total_cpu_time", int),
        ("bits_sent", int),
        ("bits_received", int),
    )
    results = defaultdict(list)
    with open(f"{directory}/{filename}", "r", encoding="utf-8") as file:
        for line in file:
            if "error" in line:
                continue
            values = line.split()
            if not values:
                # Blank line: previously produced an empty record and a KeyError below.
                continue
            record = {name: convert(value) for (name, convert), value in zip(keys, values)}
            group = (
                record["num_servers"],
                record["database_size"],
                record["block_size"],
                record["protocol_name"],
            )
            results[group].append(record)
    return results
def clean_results(results) -> dict:
    """Average the repeated runs of each test case and regroup them by protocol.

    Args:
        results: mapping of test-case key -> list of run records, as returned by
            ``load_results``. Each list is assumed to be non-empty.

    Returns:
        A ``defaultdict(list)`` mapping protocol name to one record per test case.
        Each record is a copy of the first run's record, with total_cpu_time,
        bits_sent and bits_received replaced by their mean over all runs.
    """
    metrics = ("total_cpu_time", "bits_sent", "bits_received")
    cleaned_results = defaultdict(list)
    for runs in results.values():
        # The mean is order-independent, so the runs are not sorted first
        # (the original sorted() call was wasted work).
        means = {metric: statistics.mean(int(run[metric]) for run in runs) for metric in metrics}
        first = runs[0]
        cleaned_results[first["protocol_name"]].append({**first, **means})
    return cleaned_results
def filter_results(results: dict, func: callable):
    """Return a new per-protocol mapping keeping only records accepted by *func*.

    Every protocol key of *results* is kept, even when no record passes the filter.
    """
    kept = {}
    for protocol_name, records in results.items():
        kept[protocol_name] = [record for record in records if func(record)]
    return kept
def save_fig(plt, title):
    """Save the current figure of *plt* as ``plots/<title>.pdf``.

    Every non-word character of *title* is replaced by "_" so that the title is
    a safe filename. The ``plots`` directory is created when missing; previously
    ``savefig`` failed with ``FileNotFoundError`` on a fresh checkout.
    """
    clean_title = re.sub(r"\W", "_", title)
    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/{clean_title}.pdf")
def with_bandwidth(result: dict, bandwidth=10):
    """Estimate total time: CPU time plus the time to transfer all traffic.

    *bandwidth* is in Mbit/s. The combined sent and received bits are divided by
    the link rate in bits per millisecond and added to ``total_cpu_time``. The
    result is clamped to at least 1.
    """
    bits_per_ms = bandwidth * 1000  # 1000 bits/ms = 1 Mbit/s
    transfer_ms = (result["bits_sent"] + result["bits_received"]) / bits_per_ms
    return max(1, result["total_cpu_time"] + transfer_ms)
def plot(all_results: dict, y_func: callable, x_func: callable, title=None, y_label=None, x_label=None,
         logx=False, logy=False, scatter=False):
    """Draw one series per protocol on a new figure and save it under *title*.

    Args:
        all_results: mapping of protocol name -> list of result records.
        y_func, x_func: map a record to its y / x coordinate. Points are drawn in
            increasing x order.
        title: used only to build the output filename (see ``save_fig``); it is
            not drawn on the figure. When None, the figure is left open and not
            saved. Previously this default crashed inside ``save_fig``.
        y_label, x_label: optional axis labels.
        logx, logy: use a base-2 logarithmic scale on that axis.
        scatter: draw unconnected markers (``ax.scatter``) instead of lines (``ax.plot``).
    """
    _, ax = plt.subplots()
    plot_func = ax.scatter if scatter else ax.plot
    for protocol_name, results in all_results.items():
        sorted_results = sorted(results, key=x_func)
        plot_func(
            [x_func(r) for r in sorted_results],
            [y_func(r) for r in sorted_results],
            label=protocol_name.replace("_", " "),
        )
    # Matplotlib >= 3.3 names the log base `base`; `basex`/`basey` were removed in 3.5.
    if logx:
        ax.set_xscale("log", base=2)
    if logy:
        ax.set_yscale("log", base=2)
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    ax.legend(loc="upper left")
    if title is not None:
        save_fig(plt, title)
def plot_3x_with_simulated_bandwidth(all_results: dict, title: str):
    """Save side-by-side plots of simulated total time at 10 and 100 Mbit/s.

    Each panel plots ``with_bandwidth(record, bandwidth)`` against the total
    database size (database_size * block_size), one line per protocol, on base-2
    log axes. The panels share both axes, so the right panel's y tick labels are
    hidden as duplicates of the left panel's.
    """
    # Start from a fresh figure rather than drawing into whatever figure is current.
    fig = plt.figure()
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122, sharex=ax1, sharey=ax1)
    ax1.set_ylabel("Time (ms)")
    # tick_params keeps the labels hidden even when ticks are regenerated by autoscaling.
    ax2.tick_params(axis="y", labelleft=False)
    for ax in (ax1, ax2):
        ax.set_xlabel("Total Database Size (bits)")
        ax.set_xscale("log", base=2)
        ax.set_yscale("log", base=2)
    ax1.set_title("10 Mbit/s")
    ax2.set_title("100 Mbit/s")

    def total_size(r):
        return r["database_size"] * r["block_size"]

    for protocol_name, results in all_results.items():
        sorted_results = sorted(results, key=total_size)
        xs = [total_size(r) for r in sorted_results]
        label = protocol_name.replace("_", " ")
        ax1.plot(xs, [with_bandwidth(r, 10) for r in sorted_results], label=label)
        ax2.plot(xs, [with_bandwidth(r, 100) for r in sorted_results], label=label)
    ax1.legend(loc="upper left")
    save_fig(plt, title)
def plot_send_receive(all_results: dict, title: str):
    """Save side-by-side plots of bits sent (left) and bits received (right).

    Both panels plot against the total database size (database_size * block_size),
    one line per protocol, on base-2 log axes. The panels share both axes, so the
    right panel's hidden y tick labels are covered by the left panel's scale.
    Previously the right panel did not share the y axis, so hiding its tick labels
    left it without a readable scale.
    """
    # Start from a fresh figure rather than drawing into whatever figure is current.
    fig = plt.figure()
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122, sharex=ax1, sharey=ax1)
    ax1.set_ylabel("Sent (bits)")
    ax2.set_ylabel("Received (bits)")
    ax2.tick_params(axis="y", labelleft=False)
    for ax in (ax1, ax2):
        ax.set_xlabel("Total Database Size (bits)")
        ax.set_xscale("log", base=2)
        ax.set_yscale("log", base=2)

    def total_size(r):
        return r["database_size"] * r["block_size"]

    for protocol_name, results in all_results.items():
        sorted_results = sorted(results, key=total_size)
        xs = [total_size(r) for r in sorted_results]
        label = protocol_name.replace("_", " ")
        # Clamp to 1 so that zero-traffic points remain drawable on the log axis.
        ax1.plot(xs, [max(1, r["bits_sent"]) for r in sorted_results], label=label)
        ax2.plot(xs, [max(1, r["bits_received"]) for r in sorted_results], label=label)
    ax1.legend(loc="upper left")
    save_fig(plt, title)
def matrixify(results: list, x_func: callable, y_func: callable, z_func: callable):
    """Arrange records on a 2-D grid indexed by their sorted distinct y and x keys.

    Returns ``(array, x_labels, y_labels)``. Row i corresponds to ``y_labels[i]``
    and column j to ``x_labels[j]``. A cell with no record holds 0; when several
    records map to the same cell, the last one wins.
    """
    xs = sorted({x_func(r) for r in results})
    ys = sorted({y_func(r) for r in results})
    column_of = {x: col for col, x in enumerate(xs)}
    row_of = {y: row for row, y in enumerate(ys)}
    grid = [[0] * len(xs) for _ in ys]
    for r in results:
        grid[row_of[y_func(r)]][column_of[x_func(r)]] = z_func(r)
    return np.array(grid), xs, ys
def plot_scheme_heatmap(results: list, title: str, bandwidth: int):
    """Save a heatmap of simulated total time for one protocol's records.

    Columns are database sizes and rows are block sizes. Each cell holds
    ``with_bandwidth(record, bandwidth)``, drawn with a logarithmic colour scale.
    Axis ticks are labelled as powers of two, ``2^floor(log2(size))``.
    """
    def power_of_two_label(size):
        return f"$2^{{{int(math.log2(size))}}}$"

    grid, database_sizes, block_sizes = matrixify(
        results,
        x_func=lambda r: r["database_size"],
        y_func=lambda r: r["block_size"],
        z_func=lambda r: with_bandwidth(r, bandwidth),
    )
    row_labels = [power_of_two_label(size) for size in block_sizes]
    column_labels = [power_of_two_label(size) for size in database_sizes]
    _image, _colorbar = util.heatmap(
        grid,
        row_labels,
        column_labels,
        xlabel="Database Size (bits)",
        ylabel="Block Size (bits)",
        cbarlabel="Time (ms)",
        logcolor=True,
        origin="lower",
    )
    save_fig(plt, title)
def plot_old_vs_new_heatmap(all_results: dict, old_func: callable, new_func: callable, title: str, bandwidth=10):
    """Save a heatmap of the simulated-time difference (new - old) between two schemes.

    *old_func* and *new_func* each pick one scheme's records out of *all_results*.
    For every (database_size, block_size) cell, the old scheme's
    ``with_bandwidth(r, bandwidth)`` time is subtracted from the new scheme's, so
    negative values mean the new scheme took less time.

    Both schemes are placed on one shared grid: the union of their database and
    block sizes. Previously each scheme was turned into its own matrix and the two
    were compared by array position. That silently compared different cells, and
    swallowed IndexError, whenever the schemes were benchmarked on different sizes.

    Cell values:
      * 0 when neither scheme has a measurement for the cell;
      * new - old otherwise, counting a missing measurement as 0;
      * 1 instead of 0 when both measurements are equal. NOTE(review): kept from
        the original; presumably this keeps equal cells distinguishable from empty
        ones under the symlog colour scale — confirm.

    *bandwidth* is in Mbit/s and defaults to the previously hard-coded 10.
    """
    def times_by_cell(records):
        return {(r["database_size"], r["block_size"]): with_bandwidth(r, bandwidth) for r in records}

    old_times = times_by_cell(old_func(all_results))
    new_times = times_by_cell(new_func(all_results))
    cells = old_times.keys() | new_times.keys()
    x_labels = sorted({database_size for database_size, _ in cells})
    y_labels = sorted({block_size for _, block_size in cells})

    def difference(database_size, block_size):
        new_time = new_times.get((database_size, block_size), 0)
        old_time = old_times.get((database_size, block_size), 0)
        if (new_time, old_time) == (0, 0):
            return 0
        diff = new_time - old_time
        return 1 if diff == 0 else diff

    # Rows follow block sizes and columns follow database sizes, matching matrixify's layout.
    data = np.array([[difference(x, y) for x in x_labels] for y in y_labels])
    im, cbar = util.heatmap(
        data,
        [f"$2^{{{int(math.log2(y))}}}$" for y in y_labels],
        [f"$2^{{{int(math.log2(x))}}}$" for x in x_labels],
        xlabel="Database Size (bits)",
        ylabel="Block Size (bits)",
        cbarlabel="Time Difference (ms)",
        sym_logcolor=True,
        origin="lower",
    )
    save_fig(plt, title)
def main():
    """Generate every figure from the two benchmark logs.

    Each log is parsed and averaged once, and the result is reused for all plots
    drawn from it; previously every plot re-read and re-averaged its log. Each
    figure is closed after it is saved, so the next plot starts clean.
    """
    combined = clean_results(load_results("results_combined.log"))
    one_bit_blocks = filter_results(combined, lambda r: r["block_size"] == 1)

    # Simple CPU time
    plot(
        one_bit_blocks,
        y_label="Time (ms)",
        x_label="Total Database Size (bits)",
        title="Computation Time - 1-bit Block Size",
        y_func=lambda r: max(1, r["total_cpu_time"]),
        x_func=lambda r: r["database_size"] * r["block_size"],
        logx=True,
        logy=True,
    )
    plt.close()

    # ... with simulated bandwidth, i.e. estimated total real time
    plot_3x_with_simulated_bandwidth(
        one_bit_blocks,
        title="Total Time with Simulated Bandwidth - 1-bit Block Size",
    )
    plt.close()

    # Simple network traffic
    plot_send_receive(
        one_bit_blocks,
        title="Network Traffic - 1-bit Block Size",
    )
    plt.close()

    varying = clean_results(load_results("results_fast_var-bs_var-db.log"))

    # 2D heatmaps of total simulated time (CPU + 10 Mbit/s transfer) per scheme,
    # varying both database size and block size
    for scheme in ("Send_All", "XOR", "Balanced_XOR"):
        scheme_label = scheme.replace("_", " ")
        plot_scheme_heatmap(
            varying[scheme],
            title=f"Total Simulated Time Heatmap: {scheme_label} - Varying Database Size and Block Size - 10Mbit/s",
            bandwidth=10,
        )
        plt.close()

    # 2D heatmaps comparing each older scheme against Balanced XOR
    for old_scheme in ("Send_All", "XOR"):
        old_label = old_scheme.replace("_", " ")
        plot_old_vs_new_heatmap(
            varying,
            old_func=lambda rs, scheme=old_scheme: rs[scheme],
            new_func=lambda rs: rs["Balanced_XOR"],
            title=f"Total Simulated Time Heatmap: {old_label} vs Balanced XOR - Varying Database Size and Block Size - 10 Mbit/s",
        )
        plt.close()
if __name__ == "__main__":
    # Generate the plots only when executed as a script, not when imported.
    main()