39 lines
1.1 KiB
Python
39 lines
1.1 KiB
Python
import sys
|
|
import pytest
|
|
from typing import Any
|
|
sys.path.append('src/')
|
|
|
|
from protocol_components.dtypes import BprotoFieldBaseType
|
|
from parser.parser import create_ast_parser, bproto_ErrorListener
|
|
from parser.ast_visitor import BprotoASTVisitor
|
|
|
|
|
|
@pytest.mark.parametrize("source_text,expected_output", [
    ("uint64\n", (BprotoFieldBaseType.UINT64, 1, None)),
    ("float32\n", (BprotoFieldBaseType.FLOAT32, 1, None)),
    ("float32[8]\n", (BprotoFieldBaseType.FLOAT32, 8, None)),
])
def test_ast_field_dtype_defintion(
        source_text: str,
        expected_output: tuple[BprotoFieldBaseType, int, Any]):
    """Parse a field dtype and check the tuple produced by the AST visitor.

    ``source_text`` is parsed with the parser's ``dtype`` rule and must not
    raise any syntax errors. The resulting tree is then visited with
    ``BprotoASTVisitor``, which must return a 3-tuple of
    ``(dtype value, array size, reference)``.

    Args:
        source_text: A bproto dtype declaration such as ``float32[8]``
            followed by a newline.
        expected_output: ``(dtype, array size, reference)``. The dtype is
            given as a ``BprotoFieldBaseType`` member and compared against
            the visitor result via its ``.value``.
    """
    err_listener = bproto_ErrorListener()
    parser = create_ast_parser(source_text, err_listener)
    ast = parser.dtype()

    vinterp = BprotoASTVisitor()

    # Check for syntax errors before visiting the tree; on failure, show the
    # collected errors instead of a bare "assert False".
    assert not err_listener.syntax_error_list, err_listener.syntax_error_list

    res = vinterp.visit(ast)
    assert isinstance(res, tuple), f"expected a tuple, got {type(res)!r}"
    assert len(res) == 3, f"expected a 3-tuple, got {res!r}"

    expected_dtype, expected_array_size, expected_reference = expected_output

    # Correct dtype; the visitor result is compared against the enum's value.
    assert res[0] == expected_dtype.value

    # Correct array size (the cases above expect 1 without a [N] suffix).
    assert res[1] == expected_array_size

    # Correct reference; every case above expects None.
    assert res[2] == expected_reference
|