Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHIA-387] DL batch upsert optimization. #17999

Merged
merged 11 commits
May 16, 2024
Merged
37 changes: 29 additions & 8 deletions chia/_tests/core/data_layer/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

from chia._tests.core.data_layer.util import Example, add_0123_example, add_01234567_example
from chia._tests.util.misc import BenchmarkRunner, Marks, datacases
from chia._tests.util.misc import BenchmarkRunner, Marks, boolean_datacases, datacases
from chia.data_layer.data_layer_errors import KeyNotFoundError, NodeHashError, TreeGenerationIncrementingError
from chia.data_layer.data_layer_util import (
DiffData,
Expand Down Expand Up @@ -1991,20 +1991,41 @@ async def test_insert_key_already_present(data_store: DataStore, store_id: bytes


@pytest.mark.anyio
@boolean_datacases(name="use_batch_autoinsert", false="not optimized batch insert", true="optimized batch insert")
async def test_batch_insert_key_already_present(
    data_store: DataStore,
    store_id: bytes32,
    use_batch_autoinsert: bool,
) -> None:
    """A second batch insert of an existing key must fail, with or without the autoinsert optimization."""
    duplicate_key = b"foo"
    duplicate_value = b"bar"
    batch = [{"action": "insert", "key": duplicate_key, "value": duplicate_value}]
    # First application succeeds and commits the key.
    await data_store.insert_batch(store_id, batch, Status.COMMITTED, use_batch_autoinsert)
    # Re-applying the identical batch must raise, reporting the duplicate key.
    with pytest.raises(Exception, match=f"Key already present: {duplicate_key.hex()}"):
        await data_store.insert_batch(store_id, batch, Status.COMMITTED, use_batch_autoinsert)


@pytest.mark.anyio
@boolean_datacases(name="use_upsert", false="update with delete and insert", true="update with upsert")
async def test_update_keys(data_store: DataStore, store_id: bytes32, use_upsert: bool) -> None:
num_keys = 10
missing_keys = 50
num_values = 10
new_keys = 10
for value in range(num_values):
changelist: List[Dict[str, Any]] = []
bytes_value = value.to_bytes(4, byteorder="big")
for key in range(num_keys + missing_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "delete", "key": bytes_key})
for key in range(num_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "insert", "key": bytes_key, "value": bytes_value})
if use_upsert:
for key in range(num_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "upsert", "key": bytes_key, "value": bytes_value})
else:
for key in range(num_keys + missing_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "delete", "key": bytes_key})
for key in range(num_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "insert", "key": bytes_key, "value": bytes_value})

await data_store.insert_batch(
store_id=store_id,
Expand Down
94 changes: 84 additions & 10 deletions chia/data_layer/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,9 @@ async def clean_node_table(self, writer: Optional[aiosqlite.Connection] = None)
else:
await writer.execute(query, params)

async def get_leaf_at_minimum_height(self, root_hash: bytes32) -> TerminalNode:
async def get_leaf_at_minimum_height(
self, root_hash: bytes32, hash_to_parent: Dict[bytes32, InternalNode]
) -> TerminalNode:
root_node = await self.get_node(root_hash)
queue: List[Node] = [root_node]
while True:
Expand All @@ -1378,11 +1380,29 @@ async def get_leaf_at_minimum_height(self, root_hash: bytes32) -> TerminalNode:
if isinstance(node, InternalNode):
left_node = await self.get_node(node.left_hash)
right_node = await self.get_node(node.right_hash)
hash_to_parent[left_node.hash] = node
hash_to_parent[right_node.hash] = node
queue.append(left_node)
queue.append(right_node)
elif isinstance(node, TerminalNode):
return node

async def batch_upsert(
    self,
    tree_id: bytes32,
    hash: bytes32,
    to_update_hashes: Set[bytes32],
    pending_upsert_new_hashes: Dict[bytes32, bytes32],
) -> bytes32:
    """Recursively rebuild the subtree rooted at ``hash``, swapping in pending upserted leaves.

    Nodes whose hash is not in ``to_update_hashes`` lie outside any update path
    and are reused unchanged. Terminal nodes on an update path are replaced by
    their pending new hash; internal nodes are re-created from their (possibly
    rebuilt) children. Returns the hash of the resulting subtree root.
    """
    # Untouched subtree: keep the existing node as-is.
    if hash not in to_update_hashes:
        return hash
    node = await self.get_node(hash)
    if isinstance(node, TerminalNode):
        # A leaf on an update path must have a pending replacement recorded.
        return pending_upsert_new_hashes[hash]
    # Rebuild both children, then materialize a fresh internal node over them.
    rebuilt_children = []
    for child_hash in (node.left_hash, node.right_hash):
        rebuilt_children.append(
            await self.batch_upsert(tree_id, child_hash, to_update_hashes, pending_upsert_new_hashes)
        )
    return await self._insert_internal_node(rebuilt_children[0], rebuilt_children[1])

async def insert_batch(
self,
store_id: bytes32,
Expand Down Expand Up @@ -1410,14 +1430,19 @@ async def insert_batch(

key_hash_frequency: Dict[bytes32, int] = {}
first_action: Dict[bytes32, str] = {}
last_action: Dict[bytes32, str] = {}

for change in changelist:
key = change["key"]
hash = key_hash(key)
key_hash_frequency[hash] = key_hash_frequency.get(hash, 0) + 1
if hash not in first_action:
first_action[hash] = change["action"]
last_action[hash] = change["action"]

pending_autoinsert_hashes: List[bytes32] = []
pending_upsert_new_hashes: Dict[bytes32, bytes32] = {}

for change in changelist:
if change["action"] == "insert":
key = change["key"]
Expand All @@ -1435,8 +1460,16 @@ async def insert_batch(
if key_hash_frequency[hash] == 1 or (
key_hash_frequency[hash] == 2 and first_action[hash] == "delete"
):
old_node = await self.maybe_get_node_by_key(key, store_id)
terminal_node_hash = await self._insert_terminal_node(key, value)
pending_autoinsert_hashes.append(terminal_node_hash)

if old_node is None:
pending_autoinsert_hashes.append(terminal_node_hash)
else:
if key_hash_frequency[hash] == 1:
raise Exception(f"Key already present: {key.hex()}")
else:
pending_upsert_new_hashes[old_node.hash] = terminal_node_hash
continue
insert_result = await self.autoinsert(
key, value, store_id, True, Status.COMMITTED, root=latest_local_root
Expand All @@ -1458,17 +1491,50 @@ async def insert_batch(
latest_local_root = insert_result.root
elif change["action"] == "delete":
key = change["key"]
hash = key_hash(key)
if key_hash_frequency[hash] == 2 and last_action[hash] == "insert" and enable_batch_autoinsert:
continue
latest_local_root = await self.delete(key, store_id, True, Status.COMMITTED, root=latest_local_root)
elif change["action"] == "upsert":
key = change["key"]
new_value = change["value"]
hash = key_hash(key)
if key_hash_frequency[hash] == 1 and enable_batch_autoinsert:
terminal_node_hash = await self._insert_terminal_node(key, new_value)
old_node = await self.maybe_get_node_by_key(key, store_id)
if old_node is not None:
pending_upsert_new_hashes[old_node.hash] = terminal_node_hash
else:
pending_autoinsert_hashes.append(terminal_node_hash)
continue
insert_result = await self.upsert(
key, new_value, store_id, True, Status.COMMITTED, root=latest_local_root
)
latest_local_root = insert_result.root
else:
raise Exception(f"Operation in batch is not insert or delete: {change}")

if len(pending_upsert_new_hashes) > 0:
to_update_hashes: Set[bytes32] = set()
for hash in pending_upsert_new_hashes.keys():
while True:
if hash in to_update_hashes:
break
to_update_hashes.add(hash)
node = await self._get_one_ancestor(hash, store_id)
if node is None:
break
hash = node.hash
assert latest_local_root is not None
assert latest_local_root.node_hash is not None
new_root_hash = await self.batch_upsert(
store_id,
latest_local_root.node_hash,
to_update_hashes,
pending_upsert_new_hashes,
)
latest_local_root = await self._insert_root(store_id, new_root_hash, Status.COMMITTED)

# Start with the leaf nodes and pair them to form new nodes at the next level up, repeating this process
# in a bottom-up fashion until a single root node remains. This constructs a balanced tree from the leaves.
while len(pending_autoinsert_hashes) > 1:
Expand All @@ -1488,14 +1554,15 @@ async def insert_batch(
if latest_local_root is None or latest_local_root.node_hash is None:
await self._insert_root(store_id=store_id, node_hash=subtree_hash, status=Status.COMMITTED)
else:
min_height_leaf = await self.get_leaf_at_minimum_height(latest_local_root.node_hash)
ancestors = await self.get_ancestors_common(
node_hash=min_height_leaf.hash,
store_id=store_id,
root_hash=latest_local_root.node_hash,
generation=latest_local_root.generation,
use_optimized=True,
)
hash_to_parent: Dict[bytes32, InternalNode] = {}
min_height_leaf = await self.get_leaf_at_minimum_height(latest_local_root.node_hash, hash_to_parent)
ancestors: List[InternalNode] = []
hash = min_height_leaf.hash
while hash in hash_to_parent:
node = hash_to_parent[hash]
ancestors.append(node)
hash = node.hash

await self.update_ancestor_hashes_on_insert(
store_id=store_id,
left=min_height_leaf.hash,
Expand Down Expand Up @@ -1631,6 +1698,13 @@ async def get_node_by_key_latest_generation(self, key: bytes, store_id: bytes32)
assert isinstance(node, TerminalNode)
return node

async def maybe_get_node_by_key(self, key: bytes, tree_id: bytes32) -> Optional[TerminalNode]:
    """Return the latest-generation terminal node for ``key``, or ``None`` when the key is absent."""
    try:
        # Delegate the lookup; absence is signaled by KeyNotFoundError rather than a sentinel.
        return await self.get_node_by_key_latest_generation(key, tree_id)
    except KeyNotFoundError:
        return None

async def get_node_by_key(
self,
key: bytes,
Expand Down