(Translated by https://www.hiragana.jp/)
[PyCDE] Support for adding module constants to manifests by teqdruid · Pull Request #7510 · llvm/circt · GitHub
Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyCDE] Support for adding module constants to manifests #7510

Merged
merged 1 commit into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions frontends/PyCDE/integration_test/esi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pycde
from pycde import (AppID, Clock, Module, Reset, modparams, generator)
from pycde.bsp import cosim
from pycde.common import Constant
from pycde.constructs import Reg, Wire
from pycde.esi import FuncService, MMIO, MMIOReadWriteCmdType
from pycde.types import (Bits, Channel, UInt)
Expand All @@ -15,21 +16,23 @@
import sys


class LoopbackInOutAdd7(Module):
class LoopbackInOutAdd(Module):
"""Loopback the request from the host, adding 7 to the first 15 bits."""
clk = Clock()
rst = Reset()

add_amt = Constant(UInt(16), 11)

@generator
def construct(ports):
loopback = Wire(Channel(UInt(16)))
args = FuncService.get_call_chans(AppID("loopback_add7"),
args = FuncService.get_call_chans(AppID("add"),
arg_type=UInt(24),
result=loopback)

ready = Wire(Bits(1))
data, valid = args.unwrap(ready)
plus7 = data + 7
plus7 = data + LoopbackInOutAdd.add_amt.value
data_chan, data_ready = loopback.type.wrap(plus7.as_uint(16), valid)
data_chan_buffered = data_chan.buffer(ports.clk, ports.rst, 5)
ready.assign(data_ready)
Expand Down Expand Up @@ -96,7 +99,7 @@ class Top(Module):

@generator
def construct(ports):
LoopbackInOutAdd7(clk=ports.clk, rst=ports.rst)
LoopbackInOutAdd(clk=ports.clk, rst=ports.rst, appid=AppID("loopback"))
for i in range(4, 18, 5):
MMIOClient(i)()
MMIOReadWriteClient(clk=ports.clk, rst=ports.rst)
Expand Down
16 changes: 12 additions & 4 deletions frontends/PyCDE/integration_test/test_software/esi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,21 @@ def read_offset_check(i: int, add_amt: int):
print(m.type_table)

d = acc.build_accelerator()

recv = d.ports[esi.AppID("loopback_add7")].read_port("result")
loopback = d.children[esi.AppID("loopback")]
recv = loopback.ports[esi.AppID("add")].read_port("result")
recv.connect()

send = d.ports[esi.AppID("loopback_add7")].write_port("arg")
send = loopback.ports[esi.AppID("add")].write_port("arg")
send.connect()

loopback_info = None
for mod_info in m.module_infos:
if mod_info.name == "LoopbackInOutAdd":
loopback_info = mod_info
break
assert loopback_info is not None
add_amt = mod_info.constants["add_amt"].value

################################################################################
# Loopback add 7 tests
################################################################################
Expand All @@ -85,6 +93,6 @@ def read_offset_check(i: int, add_amt: int):

print(f"data: {data}")
print(f"resp: {resp}")
assert resp == data + 7
assert resp == data + add_amt

print("PASS")
17 changes: 17 additions & 0 deletions frontends/PyCDE/src/pycde/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,23 @@ def __repr__(self) -> str:
return f"{self.name}[{self.index}]"


class Constant:
"""A constant value associated with a module. Gets added to the ESI system
manifest so it is accessible at runtime.

Example usage:

```
def ExampleModule(Module):
const_name = Constant(UInt(16), 42)
```
"""

def __init__(self, type: Type, value: object):
self.type = type
self.value = value


class _PyProxy:
"""Parent class for a Python object which has a corresponding IR op (i.e. a
proxy class)."""
Expand Down
21 changes: 18 additions & 3 deletions frontends/PyCDE/src/pycde/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Set, Tuple, Dict

from .common import (AppID, Clock, Input, ModuleDecl, Output, PortError,
_PyProxy, Reset)
from .support import (get_user_loc, _obj_to_attribute, create_type_string,
from .common import (AppID, Clock, Constant, Input, ModuleDecl, Output,
PortError, _PyProxy, Reset)
from .support import (get_user_loc, _obj_to_attribute, obj_to_typed_attribute,
create_const_zero)
from .signals import ClockSignal, Signal, _FromCirctValue
from .types import ClockType, Type, _FromCirctType
Expand Down Expand Up @@ -237,6 +237,7 @@ def scan_cls(self):
clock_ports = set()
reset_ports = set()
generators = {}
constants = {}
num_inputs = 0
num_outputs = 0
for attr_name, attr in self.cls_dct.items():
Expand Down Expand Up @@ -273,11 +274,14 @@ def scan_cls(self):
ports.append(attr)
elif isinstance(attr, Generator):
generators[attr_name] = attr
elif isinstance(attr, Constant):
constants[attr_name] = attr

self.ports = ports
self.clocks = clock_ports
self.resets = reset_ports
self.generators = generators
self.constants = constants

def create_port_proxy(self) -> PortProxyBase:
"""Create a proxy class for generators to use in order to access module
Expand Down Expand Up @@ -475,6 +479,17 @@ def create_op(self, sys, symbol):
else:
self.add_metadata(sys, symbol, None)

# If there are associated constants, add them to the manifest.
if len(self.constants) > 0:
constants_dict: Dict[str, ir.Attribute] = {}
for name, constant in self.constants.items():
constant_attr = obj_to_typed_attribute(constant.value, constant.type)
constants_dict[name] = constant_attr
with ir.InsertionPoint(sys.mod.body):
from .dialects.esi import esi
esi.SymbolConstantsOp(symbolRef=ir.FlatSymbolRefAttr.get(symbol),
constants=ir.DictAttr.get(constants_dict))

if len(self.generators) > 0:
if hasattr(self, "parameters") and self.parameters is not None:
self.attributes["pycde.parameters"] = self.parameters
Expand Down
13 changes: 13 additions & 0 deletions frontends/PyCDE/src/pycde/support.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from __future__ import annotations

from .circt import support
from .circt import ir

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .types import Type

import os


Expand Down Expand Up @@ -43,6 +49,13 @@ def _obj_to_attribute(obj) -> ir.Attribute:
"This is required for parameters.")


def obj_to_typed_attribute(obj: object, type: Type) -> ir.Attribute:
from .types import BitVectorType
if isinstance(type, BitVectorType):
return ir.IntegerAttr.get(type._type, obj)
raise ValueError(f"Type '{type}' conversion to attribute not supported yet.")


__dir__ = os.path.dirname(__file__)
_local_files = set([os.path.join(__dir__, x) for x in os.listdir(__dir__)])
_hidden_filenames = set(["functools.py"])
Expand Down
5 changes: 4 additions & 1 deletion frontends/PyCDE/test/test_esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pycde import (Clock, Input, InputChannel, Output, OutputChannel, Module,
Reset, generator, types)
from pycde import esi
from pycde.common import AppID, RecvBundle, SendBundle
from pycde.common import AppID, Constant, RecvBundle, SendBundle
from pycde.constructs import Wire
from pycde.esi import MMIO
from pycde.module import Metadata
Expand Down Expand Up @@ -36,6 +36,7 @@ class HostComms:


# CHECK: esi.manifest.sym @LoopbackInOutTop name "LoopbackInOut" {{.*}}version "0.1" {bar = "baz", foo = 1 : i64}
# CHECK: esi.manifest.constants @LoopbackInOutTop {c1 = 54 : ui8}


# CHECK-LABEL: hw.module @LoopbackInOutTop(in %clk : !seq.clock, in %rst : i1)
Expand All @@ -59,6 +60,8 @@ class LoopbackInOutTop(Module):
},
)

c1 = Constant(UInt(8), 54)

@generator
def construct(self):
# Use Cosim to implement the 'HostComms' service.
Expand Down
Loading