import sys
from typing import Any

import pytest

# Make the project sources importable; must run before the project imports below.
# NOTE(review): path is relative to the current working directory — confirm
# tests are always launched from the repository root.
sys.path.append('src/')

from protocol_components.dtypes import BprotoFieldBaseType  # noqa: E402
from parser.parser import create_ast_parser, bproto_ErrorListener  # noqa: E402
from parser.ast_visitor import BprotoASTVisitor  # noqa: E402


@pytest.mark.parametrize("source_text,expected_output", [
    ("uint64\n", (BprotoFieldBaseType.UINT64, 1, None)),
    ("float32\n", (BprotoFieldBaseType.FLOAT32, 1, None)),
    ("float32[8]\n", (BprotoFieldBaseType.FLOAT32, 8, None)),
])
def test_ast_field_dtype_defintion(
        source_text: str,
        expected_output: tuple[BprotoFieldBaseType, int, Any]):
    """Parse a single dtype expression and check the visitor's result.

    ``source_text`` is fed to the ``dtype`` parser rule. The parse must
    report no syntax errors. Visiting the resulting tree must produce a
    3-tuple ``(dtype value, array size, reference)``.

    ``expected_output`` holds the expected dtype as a ``BprotoFieldBaseType``
    member, the expected array size, and the expected reference.
    """
    expected_dtype, expected_size, expected_ref = expected_output

    err_listener = bproto_ErrorListener()
    parser = create_ast_parser(source_text, err_listener)
    tree = parser.dtype()
    visitor = BprotoASTVisitor()

    # Parsing happens in parser.dtype() above, so any syntax errors are
    # already recorded here; show them if the check fails.
    assert len(err_listener.syntax_error_list) == 0, \
        err_listener.syntax_error_list

    res = visitor.visit(tree)
    assert isinstance(res, tuple)
    assert len(res) == 3

    # Correct dtype. The test compares against the enum's .value, not the
    # enum member itself.
    assert res[0] == expected_dtype.value

    # Correct array size (1 for a scalar such as "uint64").
    assert res[1] == expected_size

    # Correct reference; None for the plain base types tested here.
    assert res[2] == expected_ref