Files
bproto/test/compiler/ast/test_ast_message.py
2025-04-14 14:43:03 +02:00

39 lines
1.1 KiB
Python

import sys
import pytest
from typing import Any
sys.path.append('src/')
from protocol_components.dtypes import BprotoFieldBaseType
from parser.parser import create_ast_parser, bproto_ErrorListener
from parser.ast_visitor import BprotoASTVisitor
@pytest.mark.parametrize(
    "source_text,expected_output",
    [
        ("uint64\n", (BprotoFieldBaseType.UINT64, 1, None)),
        ("float32\n", (BprotoFieldBaseType.FLOAT32, 1, None)),
        ("float32[8]\n", (BprotoFieldBaseType.FLOAT32, 8, None)),
    ],
)
def test_ast_field_dtype_defintion(
    source_text: str,
    expected_output: tuple[BprotoFieldBaseType, int, Any],
):
    """Parse a single ``dtype`` rule and check the tuple the AST visitor returns.

    The visitor result is checked as a 3-tuple ``(dtype_value, array_size,
    reference)``. ``dtype_value`` is compared against the ``.value`` of the
    expected ``BprotoFieldBaseType``. ``array_size`` is 1 for the scalar cases
    and 8 for ``float32[8]``. ``reference`` is ``None`` for every case listed
    in the parametrization above.
    """
    error_listener = bproto_ErrorListener()
    dtype_tree = create_ast_parser(source_text, error_listener).dtype()
    visitor = BprotoASTVisitor()

    # The parse must be clean before the resulting tree is worth inspecting.
    assert len(error_listener.syntax_error_list) == 0

    result = visitor.visit(dtype_tree)
    assert isinstance(result, tuple)
    assert len(result) == 3

    dtype_value, array_size, reference = result
    expected_dtype, expected_size, expected_reference = expected_output

    # Base type: the visitor reports the enum's raw value, not the enum member.
    assert dtype_value == expected_dtype.value
    # Array size: 1 for a scalar, N for an "[N]" suffix.
    assert array_size == expected_size
    # Type reference: None for the built-in dtypes exercised here.
    assert reference == expected_reference