I’m stuck on Stage 3
While importing inputs to the main() from sys I get strange error at the stdout, but when I tested it with the same string, though, hardcoded - output was correct.
Here are my logs:
[tester::#SZ4] Running tests for Stage #SZ4 (Print table names)
remote: [tester::#SZ4] Creating test.db with tables: [coconut coffee grape vanilla watermelon]
remote: [tester::#SZ4] $ ./your_program.sh test.db .tables
remote: [your_program]
remote: [tester::#SZ4] Expected stdout to contain "coconut coffee grape vanilla watermelon", got: "\n"
remote: [tester::#SZ4] Test failed
And here’s a snippet of my code:
import struct
import sys
from dataclasses import dataclass
def parse_varint(database_file):
"""
Reads a variable-length integer (varint) from the database file.
Returns the decoded integer.
"""
value = 0
shift = 0
for _ in range(9): # Varints can be up to 9 bytes long
byte = database_file.read(1)
if not byte:
raise EOFError("Unexpected end of file while reading varint.")
byte = byte[0]
value = (value << 7) | (byte & 0x7F)
if (byte & 0x80) == 0:
return value
# The 9th byte is shifted by 8 instead of 7
byte = database_file.read(1)
if not byte:
raise EOFError("Unexpected end of file while reading varint.")
return (value << 8) | byte[0]
def read_varint_from_bytes(data, offset):
"""
Reads a varint from bytes starting at offset.
Returns the integer value and the number of bytes consumed.
"""
value = 0
num_bytes = 0
for _ in range(9):
if offset + num_bytes >= len(data):
raise ValueError("Not enough bytes to read varint.")
byte = data[offset + num_bytes]
value = (value << 7) | (byte & 0x7F)
num_bytes += 1
if (byte & 0x80) == 0:
return value, num_bytes
# The 9th byte is shifted by 8 instead of 7
if offset + num_bytes >= len(data):
raise ValueError("Not enough bytes to read varint.")
byte = data[offset + num_bytes]
value = (value << 8) | byte
num_bytes += 1
return value, num_bytes
def parse_record(database_file):
"""
Parses a record from the database file, correctly handling different data types.
Returns a list of column values.
"""
# Record the current file offset to track the start of the header
header_start_offset = database_file.tell()
# Read the size of the header (stored as a varint)
header_size = parse_varint(database_file)
header_bytes_read = database_file.tell() - header_start_offset
# Read the remaining header bytes
remaining_header_bytes = header_size - header_bytes_read
header_bytes = database_file.read(remaining_header_bytes)
# Parse serial types from header_bytes
serial_types = []
offset = 0
while offset < len(header_bytes):
serial_type, num_bytes = read_varint_from_bytes(header_bytes, offset)
serial_types.append(serial_type)
offset += num_bytes
# Data starts immediately after the header
# No need to seek as we're already at the correct position
# Parse the data values based on the serial types
column_values = []
for serial_type in serial_types:
value = parse_serial_type(database_file, serial_type)
column_values.append(value)
return column_values
def parse_serial_type(database_file, serial_type):
"""
Reads data from the database file based on the serial type.
Returns the corresponding value.
"""
if serial_type == 0: # NULL
return None
elif serial_type == 1: # 1-byte signed integer
data = database_file.read(1)
return int.from_bytes(data, byteorder="big", signed=True)
elif serial_type == 2: # 2-byte signed integer
data = database_file.read(2)
return int.from_bytes(data, byteorder="big", signed=True)
elif serial_type == 3: # 3-byte signed integer
data = b'\x00' + database_file.read(3)
return int.from_bytes(data, byteorder="big", signed=True)
elif serial_type == 4: # 4-byte signed integer
data = database_file.read(4)
return int.from_bytes(data, byteorder="big", signed=True)
elif serial_type == 5: # 6-byte signed integer
data = b'\x00' + database_file.read(6)
return int.from_bytes(data, byteorder="big", signed=True)
elif serial_type == 6: # 8-byte signed integer
data = database_file.read(8)
return int.from_bytes(data, byteorder="big", signed=True)
elif serial_type == 7: # 8-byte floating point number
data = database_file.read(8)
return struct.unpack(">d", data)[0]
elif serial_type == 8: # Integer value 0
return 0
elif serial_type == 9: # Integer value 1
return 1
elif serial_type >= 12 and serial_type % 2 == 0: # BLOB
length = (serial_type - 12) // 2
return database_file.read(length) # Store BLOB data as bytes
elif serial_type >= 13 and serial_type % 2 == 1: # TEXT
length = (serial_type - 13) // 2
data = database_file.read(length)
return data.decode("utf-8", errors="replace")
else:
raise ValueError(f"Unsupported serial type: {serial_type}")
def parse_page_header(database_file):
"""
Parses a page header as per SQLite format.
Returns a dictionary representing the page header.
"""
header_format = ">BHHHB"
header_size = struct.calcsize(header_format)
header_data = database_file.read(header_size)
if len(header_data) != header_size:
raise EOFError("Failed to read page header.")
page_type, first_free_block_start, number_of_cells, start_of_content_area, fragmented_free_bytes = struct.unpack(header_format, header_data)
return {
"page_type": page_type,
"first_free_block_start": first_free_block_start,
"number_of_cells": number_of_cells,
"start_of_content_area": start_of_content_area,
"fragmented_free_bytes": fragmented_free_bytes,
}
def parse_cell_pointers(database_file, number_of_cells):
"""
Reads cell pointers from the cell pointer array.
"""
cell_pointers = []
for _ in range(number_of_cells):
cell_pointer_data = database_file.read(2)
if len(cell_pointer_data) != 2:
raise EOFError("Failed to read cell pointer.")
cell_pointer = int.from_bytes(cell_pointer_data, "big")
cell_pointers.append(cell_pointer)
return cell_pointers
def parse_sqlite_schema(database_file, cell_pointers):
"""
Parses each cell in the sqlite_schema table and extracts rows.
"""
sqlite_schema_rows = []
for cell_pointer in cell_pointers:
database_file.seek(cell_pointer)
#_payload_size = parse_varint(database_file) # Read payload size
record = parse_record(database_file)
if len(record) >= 5:
sqlite_schema_rows.append({
"type": record[0],
"name": record[1],
"tbl_name": record[2],
"rootpage": record[3],
"sql": record[4],
})
return sqlite_schema_rows
def main():
# Read command-line arguments
database_file_path = sys.argv[1]
command = sys.argv[2]
if command in [".dbinfo", ".tables"]:
with open(database_file_path, "rb") as database_file:
# Skip the database file header (first 100 bytes)
database_file.seek(100)
# Parse the page header
page_header = parse_page_header(database_file)
number_of_cells = page_header["number_of_cells"]
# Parse cell pointers from the Cell Pointer Array
cell_pointers = parse_cell_pointers(database_file, number_of_cells)
# Extract rows from the sqlite_schema table
sqlite_schema_rows = parse_sqlite_schema(database_file, cell_pointers)
if command == ".dbinfo":
page_size = int.from_bytes(database_file.read(2), byteorder="big")
print(f"database page size: {page_size}")
# Count only the tables (exclude indexes and other objects)
table_count = sum(
1 for row in sqlite_schema_rows
if row["type"] == "table" and not row["tbl_name"].startswith('sqlite_')
)
print(f"number of tables: {table_count}")
elif command == ".tables":
# Display all table names
table_names = [
row["tbl_name"] for row in sqlite_schema_rows
if row["type"] == "table" and not row["tbl_name"].startswith('sqlite_')
]
print(" ".join(table_names))
else:
print(f"Invalid command: {command}")
main()