Strange error at test runs

I’m stuck on Stage 3

While importing inputs to the main() from sys I get strange error at the stdout, but when I tested it with the same string, though, hardcoded - output was correct.

Here are my logs:

[tester::#SZ4] Running tests for Stage #SZ4 (Print table names)
remote: [tester::#SZ4] Creating test.db with tables: [coconut coffee grape vanilla watermelon]
remote: [tester::#SZ4] $ ./your_program.sh test.db .tables
remote: [your_program]
remote: [tester::#SZ4] Expected stdout to contain "coconut coffee grape vanilla watermelon", got: "\n"
remote: [tester::#SZ4] Test failed

And here’s a snippet of my code:

import struct
import sys
from dataclasses import dataclass

def parse_varint(database_file):
    """
    Reads a variable-length integer (varint) from the database file.
    Returns the decoded integer.
    """
    value = 0
    shift = 0
    
    for _ in range(9):  # Varints can be up to 9 bytes long
        byte = database_file.read(1)
        if not byte:
            raise EOFError("Unexpected end of file while reading varint.")
        byte = byte[0]
        value = (value << 7) | (byte & 0x7F)
        if (byte & 0x80) == 0:
            return value
    # The 9th byte is shifted by 8 instead of 7
    byte = database_file.read(1)
    if not byte:
        raise EOFError("Unexpected end of file while reading varint.")
    return (value << 8) | byte[0]

def read_varint_from_bytes(data, offset):
    """
    Reads a varint from bytes starting at offset.
    Returns the integer value and the number of bytes consumed.
    """
    value = 0
    num_bytes = 0
    for _ in range(9):
        if offset + num_bytes >= len(data):
            raise ValueError("Not enough bytes to read varint.")
        byte = data[offset + num_bytes]
        value = (value << 7) | (byte & 0x7F)
        num_bytes += 1
        if (byte & 0x80) == 0:
            return value, num_bytes
    # The 9th byte is shifted by 8 instead of 7
    if offset + num_bytes >= len(data):
        raise ValueError("Not enough bytes to read varint.")
    byte = data[offset + num_bytes]
    value = (value << 8) | byte
    num_bytes += 1
    return value, num_bytes

def parse_record(database_file):
    """
    Parses a record from the database file, correctly handling different data types.
    Returns a list of column values.
    """
    # Record the current file offset to track the start of the header
    header_start_offset = database_file.tell()
    
    # Read the size of the header (stored as a varint)
    header_size = parse_varint(database_file)
    header_bytes_read = database_file.tell() - header_start_offset
    
    # Read the remaining header bytes
    remaining_header_bytes = header_size - header_bytes_read
    header_bytes = database_file.read(remaining_header_bytes)
    
    # Parse serial types from header_bytes
    serial_types = []
    offset = 0
    while offset < len(header_bytes):
        serial_type, num_bytes = read_varint_from_bytes(header_bytes, offset)
        serial_types.append(serial_type)
        offset += num_bytes
    
    # Data starts immediately after the header
    # No need to seek as we're already at the correct position
    # Parse the data values based on the serial types
    column_values = []
    for serial_type in serial_types:
        value = parse_serial_type(database_file, serial_type)
        column_values.append(value)
        
    return column_values

def parse_serial_type(database_file, serial_type):
    """
    Reads data from the database file based on the serial type.
    Returns the corresponding value.
    """
    if serial_type == 0:  # NULL
        return None
    elif serial_type == 1:  # 1-byte signed integer
        data = database_file.read(1)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 2:  # 2-byte signed integer
        data = database_file.read(2)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 3:  # 3-byte signed integer
        data = b'\x00' + database_file.read(3)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 4:  # 4-byte signed integer
        data = database_file.read(4)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 5:  # 6-byte signed integer
        data = b'\x00' + database_file.read(6)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 6:  # 8-byte signed integer
        data = database_file.read(8)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 7:  # 8-byte floating point number
        data = database_file.read(8)
        return struct.unpack(">d", data)[0]
    elif serial_type == 8:  # Integer value 0
        return 0
    elif serial_type == 9:  # Integer value 1
        return 1
    elif serial_type >= 12 and serial_type % 2 == 0:  # BLOB
        length = (serial_type - 12) // 2
        return database_file.read(length)  # Store BLOB data as bytes
    elif serial_type >= 13 and serial_type % 2 == 1:  # TEXT
        length = (serial_type - 13) // 2
        data = database_file.read(length)
        return data.decode("utf-8", errors="replace")
    else:
        raise ValueError(f"Unsupported serial type: {serial_type}")

def parse_page_header(database_file):
    """
    Parses a page header as per SQLite format.
    Returns a dictionary representing the page header.
    """
    header_format = ">BHHHB"
    header_size = struct.calcsize(header_format)
    header_data = database_file.read(header_size)
    if len(header_data) != header_size:
        raise EOFError("Failed to read page header.")
    page_type, first_free_block_start, number_of_cells, start_of_content_area, fragmented_free_bytes = struct.unpack(header_format, header_data)
    
    return {
        "page_type": page_type,
        "first_free_block_start": first_free_block_start,
        "number_of_cells": number_of_cells,
        "start_of_content_area": start_of_content_area,
        "fragmented_free_bytes": fragmented_free_bytes,
    }

def parse_cell_pointers(database_file, number_of_cells):
    """
    Reads cell pointers from the cell pointer array.
    """
    cell_pointers = []
    for _ in range(number_of_cells):
        cell_pointer_data = database_file.read(2)
        if len(cell_pointer_data) != 2:
            raise EOFError("Failed to read cell pointer.")
        cell_pointer = int.from_bytes(cell_pointer_data, "big")
        cell_pointers.append(cell_pointer)
    return cell_pointers

def parse_sqlite_schema(database_file, cell_pointers):
    """
    Parses each cell in the sqlite_schema table and extracts rows.
    """
    sqlite_schema_rows = []
    for cell_pointer in cell_pointers:
        database_file.seek(cell_pointer)
        #_payload_size = parse_varint(database_file)  # Read payload size
        record = parse_record(database_file)
        if len(record) >= 5:
            sqlite_schema_rows.append({
                "type": record[0],
                "name": record[1],
                "tbl_name": record[2],
                "rootpage": record[3],
                "sql": record[4],
            })
    return sqlite_schema_rows

def main():

    # Read command-line arguments
    database_file_path = sys.argv[1]
    command = sys.argv[2]

    if command in [".dbinfo", ".tables"]:
        with open(database_file_path, "rb") as database_file:
            
            # Skip the database file header (first 100 bytes)
            database_file.seek(100)
            
            # Parse the page header
            page_header = parse_page_header(database_file)
            number_of_cells = page_header["number_of_cells"]
            
            # Parse cell pointers from the Cell Pointer Array
            cell_pointers = parse_cell_pointers(database_file, number_of_cells)
            
            # Extract rows from the sqlite_schema table
            sqlite_schema_rows = parse_sqlite_schema(database_file, cell_pointers)
            
            if command == ".dbinfo":
                page_size = int.from_bytes(database_file.read(2), byteorder="big")
                print(f"database page size: {page_size}")
                # Count only the tables (exclude indexes and other objects)
                table_count = sum(
                    1 for row in sqlite_schema_rows
                    if row["type"] == "table" and not row["tbl_name"].startswith('sqlite_')
                )
                print(f"number of tables: {table_count}")
            elif command == ".tables":
                # Display all table names
                table_names = [
                    row["tbl_name"] for row in sqlite_schema_rows
                    if row["type"] == "table" and not row["tbl_name"].startswith('sqlite_')
                ]
                print(" ".join(table_names))
    else:
        print(f"Invalid command: {command}")

main()

@FykerF Could you elaborate a bit on “importing inputs to the main() from sys” and the hardcoded string? I’d like to replicate the correct output to better understand how your code works.


I tried to run the code with codecrafters test --previous but it couldn’t pass the first stage anymore:

The database page size resides in the database file header, so should be read before database_file.seek(100).

Hi!

Here is basic inputs to main:

  database_file_path = sys.argv[1]
  command = sys.argv[2]

while i do

 database_file_path = 'sample.db
 command = '.tables'/'.dbinfo'

You are correct, I forgot about page_size. Thank you for pointing it out!

Here is corrected version where both page_size and names of the tables are correctly parsed:

import struct
import sys
from dataclasses import dataclass

def parse_varint(database_file):
    """
    Reads a variable-length integer (varint) from the database file.
    Returns the decoded integer.
    """
    value = 0
    shift = 0
    
    for _ in range(9):  # Varints can be up to 9 bytes long
        byte = database_file.read(1)
        if not byte:
            raise EOFError("Unexpected end of file while reading varint.")
        byte = byte[0]
        value = (value << 7) | (byte & 0x7F)
        if (byte & 0x80) == 0:
            return value
    # The 9th byte is shifted by 8 instead of 7
    byte = database_file.read(1)
    if not byte:
        raise EOFError("Unexpected end of file while reading varint.")
    return (value << 8) | byte[0]

def read_varint_from_bytes(data, offset):
    """
    Reads a varint from bytes starting at offset.
    Returns the integer value and the number of bytes consumed.
    """
    value = 0
    num_bytes = 0
    for _ in range(9):
        if offset + num_bytes >= len(data):
            raise ValueError("Not enough bytes to read varint.")
        byte = data[offset + num_bytes]
        value = (value << 7) | (byte & 0x7F)
        num_bytes += 1
        if (byte & 0x80) == 0:
            return value, num_bytes
    # The 9th byte is shifted by 8 instead of 7
    if offset + num_bytes >= len(data):
        raise ValueError("Not enough bytes to read varint.")
    byte = data[offset + num_bytes]
    value = (value << 8) | byte
    num_bytes += 1
    return value, num_bytes

def parse_record(database_file):
    """
    Parses a record from the database file, correctly handling different data types.
    Returns a list of column values.
    """
    # Record the current file offset to track the start of the header
    header_start_offset = database_file.tell()
    
    # Read the size of the header (stored as a varint)
    header_size = parse_varint(database_file)
    header_bytes_read = database_file.tell() - header_start_offset
    
    # Read the remaining header bytes
    remaining_header_bytes = header_size - header_bytes_read
    header_bytes = database_file.read(remaining_header_bytes)
    
    # Parse serial types from header_bytes
    serial_types = []
    offset = 0
    while offset < len(header_bytes):
        serial_type, num_bytes = read_varint_from_bytes(header_bytes, offset)
        serial_types.append(serial_type)
        offset += num_bytes
    
    # Data starts immediately after the header
    # No need to seek as we're already at the correct position
    # Parse the data values based on the serial types
    column_values = []
    for serial_type in serial_types:
        value = parse_serial_type(database_file, serial_type)
        column_values.append(value)
        
    return column_values

def parse_serial_type(database_file, serial_type):
    """
    Reads data from the database file based on the serial type.
    Returns the corresponding value.
    """
    if serial_type == 0:  # NULL
        return None
    elif serial_type == 1:  # 1-byte signed integer
        data = database_file.read(1)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 2:  # 2-byte signed integer
        data = database_file.read(2)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 3:  # 3-byte signed integer
        data = b'\x00' + database_file.read(3)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 4:  # 4-byte signed integer
        data = database_file.read(4)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 5:  # 6-byte signed integer
        data = b'\x00' + database_file.read(6)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 6:  # 8-byte signed integer
        data = database_file.read(8)
        return int.from_bytes(data, byteorder="big", signed=True)
    elif serial_type == 7:  # 8-byte floating point number
        data = database_file.read(8)
        return struct.unpack(">d", data)[0]
    elif serial_type == 8:  # Integer value 0
        return 0
    elif serial_type == 9:  # Integer value 1
        return 1
    elif serial_type >= 12 and serial_type % 2 == 0:  # BLOB
        length = (serial_type - 12) // 2
        return database_file.read(length)  # Store BLOB data as bytes
    elif serial_type >= 13 and serial_type % 2 == 1:  # TEXT
        length = (serial_type - 13) // 2
        data = database_file.read(length)
        return data.decode("utf-8", errors="replace")
    else:
        raise ValueError(f"Unsupported serial type: {serial_type}")

def parse_page_header(database_file):
    """
    Parses a page header as per SQLite format.
    Returns a dictionary representing the page header.
    """
    header_format = ">BHHHB"
    header_size = struct.calcsize(header_format)
    header_data = database_file.read(header_size)
    if len(header_data) != header_size:
        raise EOFError("Failed to read page header.")
    page_type, first_free_block_start, number_of_cells, start_of_content_area, fragmented_free_bytes = struct.unpack(header_format, header_data)
    
    return {
        "page_type": page_type,
        "first_free_block_start": first_free_block_start,
        "number_of_cells": number_of_cells,
        "start_of_content_area": start_of_content_area,
        "fragmented_free_bytes": fragmented_free_bytes,
    }

def parse_cell_pointers(database_file, number_of_cells):
    """
    Reads cell pointers from the cell pointer array.
    """
    cell_pointers = []
    for _ in range(number_of_cells):
        cell_pointer_data = database_file.read(2)
        if len(cell_pointer_data) != 2:
            raise EOFError("Failed to read cell pointer.")
        cell_pointer = int.from_bytes(cell_pointer_data, "big")
        cell_pointers.append(cell_pointer)
    return cell_pointers

def parse_sqlite_schema(database_file, cell_pointers):
    """
    Parses each cell in the sqlite_schema table and extracts rows.
    """
    sqlite_schema_rows = []
    for cell_pointer in cell_pointers:
        database_file.seek(cell_pointer)
        _payload_size = parse_varint(database_file)  # Read payload size
        _rowid = parse_varint(database_file)  # Read the rowid
        record = parse_record(database_file)
        if len(record) >= 5:
            sqlite_schema_rows.append({
                "type": record[0],
                "name": record[1],
                "tbl_name": record[2],
                "rootpage": record[3],
                "sql": record[4],
            })
    return sqlite_schema_rows

def main():
    database_file_path = sys.argv[1]
    command = sys.argv[2]
    
    if command in [".dbinfo", ".tables"]:
        with open(database_file_path, "rb") as database_file:
            # Skip the database file header (first 100 bytes)
            database_file.seek(100)
            
            # Parse the page header
            page_header = parse_page_header(database_file)
            number_of_cells = page_header["number_of_cells"]
            
            # Parse cell pointers from the Cell Pointer Array
            cell_pointers = parse_cell_pointers(database_file, number_of_cells)
            
            # Extract rows from the sqlite_schema table
            sqlite_schema_rows = parse_sqlite_schema(database_file, cell_pointers)
            
            if command == ".dbinfo":
                # Count only the tables (exclude indexes and other objects)
                database_file.seek(16)
                page_size = int.from_bytes(database_file.read(2), byteorder="big")
                print(f"database page size: {page_size}")
                # Count only the tables (exclude indexes and other objects)
                table_count = sum(
                    1 for row in sqlite_schema_rows
                    if row["type"] == "table" and not row["tbl_name"].startswith('sqlite_')
                )
                print(f"number of tables: {table_count}")
            elif command == ".tables":
                # Display all table names
                table_names = [
                    row["tbl_name"] for row in sqlite_schema_rows
                    if row["type"] == "table" and not row["tbl_name"].startswith('sqlite_')
                ]
                print(" ".join(table_names))
    else:
        print(f"Invalid command: {command}")

if __name__ == "__main__":
    main()

1 Like

This topic was automatically closed 5 days after the last reply. New replies are no longer allowed.