Is there a way to ensure that all my ctypes have argtypes?

I know I should specify argtypes for my C/C++ functions since some of my calls would otherwise result in stack corruption.
myCfunc.argtypes = [ct.c_void_p, ct.POINTER(ct.c_void_p)]
myCfunc.errcheck = my_error_check
In fact, I would like to verify that I did not forget to specify function prototypes (argtypes/errcheck) for any of my roughly 100 function calls...
Right now I just grep through my Python files and visually compare against my file containing the prototype definitions.
Is there a better way to verify that I have defined argtypes/errcheck for all my calls?

The mention of namespaces by @eryksun made me wrap the dll in a class that only exposes the explicitly annotated functions. As long as the dll doesn't have functions named "annotate" or "_error_check" (which mine didn't), the following approach seems to work for me:
import ctypes as ct

class MyWinDll:
    def __init__(self, dll_filename):
        self._dll = ct.WinDLL(dll_filename)
        # Specify function prototypes using the annotate function
        self.annotate(self._dll.myCfunc, [ct.POINTER(ct.c_void_p)], self._error_check)
        self.annotate(self._dll.myCfunc2, [ct.c_void_p], self._error_check)
        ...

    def annotate(self, function, argtypes, errcheck):
        # note that "annotate" may not be used as a function name in the dll...
        function.argtypes = argtypes
        function.errcheck = errcheck
        setattr(self, function.__name__, function)

    def _error_check(self, result, func, arguments):
        if result != 0:
            raise Exception

if __name__ == '__main__':
    dll = MyWinDll('myWinDll.dll')
    handle = ct.c_void_p(None)
    # Now call the dll functions using the wrapper object
    dll.myCfunc(ct.byref(handle))
    dll.myCfunc2(handle)
Update: Comments by @eryksun made me try to improve the code by giving the user control of the WinDLL constructor and attempting to reduce repeated code:
import ctypes as ct

DEFAULT = object()

def annotate(dll_object, function_name, argtypes, restype=DEFAULT, errcheck=DEFAULT):
    function = getattr(dll_object._dll, function_name)
    function.argtypes = argtypes
    # restype and errcheck are optional in the function_prototypes list
    if restype is DEFAULT:
        restype = dll_object.default_restype
    function.restype = restype
    if errcheck is DEFAULT:
        errcheck = dll_object.default_errcheck
    function.errcheck = errcheck
    setattr(dll_object, function_name, function)

class MyDll:
    def __init__(self, ct_dll, **function_prototypes):
        self._dll = ct_dll
        for name, prototype in function_prototypes.items():
            annotate(self, name, *prototype)

class OneDll(MyDll):
    def __init__(self, ct_dll):
        # set default values for function_prototypes
        self.default_restype = ct.c_int
        self.default_errcheck = self._error_check
        function_prototypes = {
            'myCfunc': [[ct.POINTER(ct.c_void_p)]],
            'myCfunc2': [[ct.c_void_p]],
            # ...
            'myCgetErrTxt': [[ct.c_int, ct.c_char_p, ct.c_size_t], DEFAULT, None]
        }
        super().__init__(ct_dll, **function_prototypes)

    # My error check function actually calls the dll, so I keep it here...
    def _error_check(self, result, func, arguments):
        msg = ct.create_string_buffer(255)
        if result != 0:
            raise Exception(self.myCgetErrTxt(result, msg, ct.sizeof(msg)))

if __name__ == '__main__':
    ct_dll = ct.WinDLL('myWinDll.dll')
    dll = OneDll(ct_dll)
    handle = ct.c_void_p(None)
    dll.myCfunc(ct.byref(handle))
    dll.myCfunc2(handle)
(I don't know if the original code should be deleted; I kept it for reference.)

Here's a dummy class that can replace the DLL object's function call with a simple check that the attributes have been defined:
class DummyFuncPtr(object):
    restype = False
    argtypes = False
    errcheck = False

    def __call__(self, *args, **kwargs):
        assert self.restype
        assert self.argtypes
        assert self.errcheck

    def __init__(self, *args):
        pass

    def __setattr__(self, key, value):
        super(DummyFuncPtr, self).__setattr__(key, True)
To use it, replace your DLL object's _FuncPtr class and then call each function to run the check, e.g.:
dll = ctypes.cdll.LoadLibrary(r'path/to/dll')
# replace the DLL's function pointer
# comment out this line to disable the dummy class
dll._FuncPtr = DummyFuncPtr
some_func = dll.someFunc
some_func.restype = None
some_func.argtypes = None
some_func.errcheck = None
another_func = dll.anotherFunc
another_func.restype = None
another_func.argtypes = None
some_func() # no error
another_func() # Assertion error due to errcheck not defined
The dummy class completely prevents the functions from ever actually being called, of course, so just comment out the replacement line to switch back to normal operation.
Note that it only checks each function when that function is called, so this is best placed in a unit test file where every function is guaranteed to be called.
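Alternatively, here's a sketch of my own (not from the thread) that avoids having to call anything: ctypes caches each function pointer as an attribute on the library object the first time it is looked up, so after all your prototype definitions have run, you can scan that cache for functions that never received argtypes or errcheck:
import ctypes as ct

dll = ct.WinDLL('myWinDll.dll')  # hypothetical DLL from the question
# ... run the code that defines all the prototypes against dll ...

# Accessed functions are cached as instance attributes, so inspect the cache
# for function pointers that are missing argtypes or errcheck.
missing = [name for name, func in vars(dll).items()
           if isinstance(func, ct._CFuncPtr)
           and (func.argtypes is None or func.errcheck is None)]
assert not missing, "prototypes missing for: %s" % ", ".join(missing)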

Related

Problem with PettingZoo and Stable-Baselines3 with a ParallelEnv

I am having trouble making things work with a custom ParallelEnv I wrote using PettingZoo. I am using SuperSuit's ss.pettingzoo_env_to_vec_env_v1(env) as a wrapper to vectorize the environment and make it work with Stable-Baselines3, as documented here.
You can find attached a summary of the most relevant part of the code:
from typing import Optional
from gym import spaces
import random
import numpy as np
from pettingzoo import ParallelEnv
from pettingzoo.utils.conversions import parallel_wrapper_fn
import supersuit as ss
from gym.utils import EzPickle, seeding

def env(**kwargs):
    env_ = parallel_env(**kwargs)
    env_ = ss.pettingzoo_env_to_vec_env_v1(env_)
    #env_ = ss.concat_vec_envs_v1(env_, 1)
    return env_

petting_zoo = env

class parallel_env(ParallelEnv, EzPickle):
    metadata = {'render_modes': ['ansi'], "name": "PlayerEnv-Multi-v0"}

    def __init__(self, n_agents: int = 20, new_step_api: bool = True) -> None:
        EzPickle.__init__(
            self,
            n_agents,
            new_step_api
        )
        self._episode_ended = False
        self.n_agents = n_agents
        self.possible_agents = [
            f"player_{idx}" for idx in range(n_agents)]
        self.agents = self.possible_agents[:]
        self.agent_name_mapping = dict(
            zip(self.possible_agents, list(range(len(self.possible_agents))))
        )
        self.observation_spaces = spaces.Dict(
            {agent: spaces.Box(shape=(len(self.agents),),
                               dtype=np.float64, low=0.0, high=1.0) for agent in self.possible_agents}
        )
        self.action_spaces = spaces.Dict(
            {agent: spaces.Discrete(4) for agent in self.possible_agents}
        )
        self.current_step = 0

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)

    def observation_space(self, agent):
        return self.observation_spaces[agent]

    def action_space(self, agent):
        return self.action_spaces[agent]

    def __calculate_observation(self, agent_id: int) -> np.ndarray:
        return self.observation_space(agent_id).sample()

    def __calculate_observations(self) -> np.ndarray:
        observations = {
            agent: self.__calculate_observation(
                agent_id=agent)
            for agent in self.agents
        }
        return observations

    def observe(self, agent):
        return self.__calculate_observation(agent_id=agent)

    def step(self, actions):
        if self._episode_ended:
            return self.reset()
        observations = self.__calculate_observations()
        rewards = random.sample(range(100), self.n_agents)
        self.current_step += 1
        self._episode_ended = self.current_step >= 100
        infos = {agent: {} for agent in self.agents}
        dones = {agent: self._episode_ended for agent in self.agents}
        rewards = {
            self.agents[i]: rewards[i]
            for i in range(len(self.agents))
        }
        if self._episode_ended:
            self.agents = {}  # To satisfy `set(par_env.agents) == live_agents`
        return observations, rewards, dones, infos

    def reset(self,
              seed: Optional[int] = None,
              return_info: bool = False,
              options: Optional[dict] = None,):
        self.agents = self.possible_agents[:]
        self._episode_ended = False
        self.current_step = 0
        observations = self.__calculate_observations()
        return observations

    def render(self, mode="human"):
        # TODO: IMPLEMENT
        print("TO BE IMPLEMENTED")

    def close(self):
        pass
Unfortunately when I try to test with the following main procedure:
from stable_baselines3 import DQN, PPO
from stable_baselines3.common.env_checker import check_env
from dummy_env import dummy
from pettingzoo.test import parallel_api_test

if __name__ == '__main__':
    # Testing the parallel algorithm alone
    env_parallel = dummy.parallel_env()
    parallel_api_test(env_parallel)  # This works!

    # Testing the environment with the wrapper
    env = dummy.petting_zoo()
    # ERROR: AssertionError: The observation returned by the `reset()` method does not match the given observation space
    check_env(env)

    # Model initialization
    model = PPO("MlpPolicy", env, verbose=1)
    # ERROR: ValueError: could not broadcast input array from shape (20,20) into shape (20,)
    model.learn(total_timesteps=10_000)
I get the following error:
AssertionError: The observation returned by the `reset()` method does not match the given observation space
If I skip check_env() I get the following one:
ValueError: could not broadcast input array from shape (20,20) into shape (20,)
It seems that ss.pettingzoo_env_to_vec_env_v1(env) is capable of splitting the parallel environment into multiple vectorized ones, but not for the reset() function.
Does anyone know how to fix this problem?
Please find the GitHub repository to reproduce the problem.
You should double-check the reset() function in PettingZoo. It can return None instead of an observation, unlike Gym.
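A quick way to check that (a hypothetical snippet, using the parallel_env class from the question):
env_ = parallel_env()
obs = env_.reset()
assert obs is not None, "reset() should return the initial observations, not None"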
Thanks to a discussion I had in the issue section of the SuperSuit repository, I am able to post the solution to the problem. Thanks to jjshoots!
First of all, it is necessary to have the latest SuperSuit version. In order to get that, I needed to install Stable-Baselines3 using the instructions here to make it work with gym 0.24+.
After that, taking the code in the question as an example, it is necessary to substitute
def env(**kwargs):
    env_ = parallel_env(**kwargs)
    env_ = ss.pettingzoo_env_to_vec_env_v1(env_)
    #env_ = ss.concat_vec_envs_v1(env_, 1)
    return env_
with
def env(**kwargs):
    env_ = parallel_env(**kwargs)
    env_ = ss.pettingzoo_env_to_vec_env_v1(env_)
    env_ = ss.concat_vec_envs_v1(env_, 1, base_class="stable_baselines3")
    return env_
The outcomes are:
Outcome 1: leaving the line with check_env(env) in, I got the error AssertionError: Your environment must inherit from the gym.Env class cf https://github.com/openai/gym/blob/master/gym/core.py
Outcome 2: removing the line with check_env(env), the agent starts training successfully!
In the end, I think the argument base_class="stable_baselines3" made the difference.
Only the small problem with check_env remains to be reported, but I think it can be considered trivial given that the training works.
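For completeness, a minimal sketch of the working setup (assuming the fixed env() factory above and the dummy_env module layout from the question; check_env is skipped because the concatenated vector env does not inherit from gym.Env):
from stable_baselines3 import PPO
from dummy_env import dummy  # module layout from the question

env = dummy.petting_zoo()  # factory with base_class="stable_baselines3"
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)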

Bizarre Environment-dependent Bad Request 400 error

I'm writing a program to convert a repository into a Docker container with an API based on some specification files. When I run the app in my MacBook's base environment, the computer-generated API works perfectly with both gunicorn and uwsgi. However, within the miniconda-based Docker container, it fails with Bad Request 400: The browser (or proxy) sent a request that this server could not understand. My goal is to eliminate this error. Obviously, this has to do with the versions of some dependency or set of dependencies. Interestingly, the last endpoint in the API, which has a request parser within a namespace with no arguments, works perfectly, unlike the two other endpoints in the default namespace that do have arguments.
The API is built on flask_restx and uses reqparse.
The API code is here:
from flask_restx import Api, Resource, Namespace, reqparse, inputs
import flask
import process
from load_data import store_data

app = flask.Flask("restful_api")
api = Api(app, title="My API", description="This is an extremely useful API for performing tasks you would do with an API.", version="3.14")
data = {}
data.update(store_data())

class DefaultClass():
    def __init__(self):
        self.data = data

    def _replace_get(self, **args):
        default_args = {}
        args = {**default_args, **args}
        return process.replace(**args)

    def _find_get(self, **args):
        default_args = {"data": self.data["data"]}
        args = {**default_args, **args}
        return process.find_in_data_string(**args)

def set_up_worker():
    global defaultClass
    defaultClass = DefaultClass()

set_up_worker()

_replaceGetParser = reqparse.RequestParser()
_replaceGetParser.add_argument("txt",
                               type=str,
                               required=True,
                               help="Text to search ")
_replaceGetParser.add_argument("old",
                               type=str,
                               required=True,
                               help="Substring to replace ")
_replaceGetParser.add_argument("new",
                               type=str,
                               required=True,
                               help="Replacement for old ")
_replaceGetParser.add_argument("irrelevant_parameter",
                               type=int,
                               required=False,
                               default=5,
                               help="")
_replaceGetParser.add_argument("smart_casing",
                               type=inputs.boolean,
                               required=False,
                               default=True,
                               help="True if we should infer replacement capitalization from original casing. ")
_replaceGetParser.add_argument("case_sensitive",
                               type=inputs.boolean,
                               required=False,
                               default=True,
                               help="True if we should only replace case-sensitive matches ")

_findGetParser = reqparse.RequestParser()
_findGetParser.add_argument("window",
                            type=int,
                            required=False,
                            default=5,
                            help="Number of characters before and after first match to return ")
_findGetParser.add_argument("txt",
                            type=str,
                            required=False,
                            default="quick",
                            help="Your search term ")

@api.route('/replace', endpoint='replace', methods=['GET'])
@api.doc('defaultClass')
class ReplaceFrontend(Resource):
    @api.expect(_replaceGetParser)
    def get(self):
        args = _replaceGetParser.parse_args()
        return defaultClass._replace_get(**args)

@api.route('/find', endpoint='find', methods=['GET'])
@api.doc('defaultClass')
class FindFrontend(Resource):
    @api.expect(_findGetParser)
    def get(self):
        args = _findGetParser.parse_args()
        return defaultClass._find_get(**args)

retrievalNamespace = Namespace("retrieval", description="Data retrieval operations")

class RetrievalNamespaceClass():
    def __init__(self):
        self.data = data

    def _retrieval_retrieve_data_get(self, **args):
        default_args = {"data": self.data["data"]}
        args = {**default_args, **args}
        return process.return_data(**args)

def set_up_retrieval_worker():
    global retrievalNamespaceClass
    retrievalNamespaceClass = RetrievalNamespaceClass()

set_up_retrieval_worker()

_retrieval_retrieve_dataGetParser = reqparse.RequestParser()

@retrievalNamespace.route('/retrieval/retrieve_data', endpoint='retrieval/retrieve_data', methods=['GET'])
@retrievalNamespace.doc('retrievalNamespaceClass')
class Retrieval_retrieve_dataFrontend(Resource):
    @retrievalNamespace.expect(_retrieval_retrieve_dataGetParser)
    def get(self):
        args = _retrieval_retrieve_dataGetParser.parse_args()
        return retrievalNamespaceClass._retrieval_retrieve_data_get(**args)

api.add_namespace(retrievalNamespace)
I have had this problem with both pip-installed gunicorn and conda-installed uwsgi. I'm putting the file imported by the API at the end, since the function definitions themselves are likely irrelevant.
import numpy as np
import pandas as pd
import re
from subprocess import Popen, PIPE
from flask_restx import abort

def replace(txt: str = '',  # apireq
            old: str = '',  # apireq
            new: str = '',  # apireq
            case_sensitive: bool = True,
            smart_casing: bool = True,
            irrelevant_parameter: int = 5):
    """
    Search and replace within a string, as long as the string and replacement
    contain no four letter words.

    arguments:
        txt: Text to search
        old: Substring to replace
        new: Replacement for old
        case_sensitive: True if we should only replace case-sensitive matches
        smart_casing: True if we should infer replacement capitalization
            from original casing.

    return
        return value
    """
    four_letter_words = [re.match('[a-zA-Z]{4}$', word).string
                         for word in ('%s %s' % (txt, new)).split()
                         if re.match('[a-zA-Z]{4}$', word)]
    if four_letter_words:
        error_message = ('Server refuses to process four letter word(s) %s'
                         % ', '.join(four_letter_words[:5])
                         + (', etc' if len(four_letter_words) > 5 else ''))
        abort(403, custom=error_message)
    return_value = {}
    if not case_sensitive:
        return_value['output'] = txt.replace(old, new)
    else:
        lowered = txt.replace(old, old.lower())
        return_value['output'] = lowered.replace(old.lower(), new)
    return return_value

def find_in_data_string(txt: str = "quick",  # req
                        window: int = 5,
                        data=None):  # noapi
    """
    Check if there is a match for your search string in our extensive database,
    and return the position of the first match with the surrounding text.

    arguments:
        txt: Your search term
        data: The server's text data
        window: Number of characters before and after first match to return
    """
    return_value = {}
    if txt in data:
        idx = data.find(txt)
        min_idx = max(idx - window, 0)
        max_idx = min(idx + len(txt) + window, len(data) - 1)
        return_value['string_found'] = True
        return_value['position'] = idx
        return_value['surrounding_string'] = data[min_idx:max_idx]
        return_value['surrounding_string_indices'] = [min_idx, max_idx]
    else:
        # fixed: the original {['string_found']: False} raises TypeError
        return_value = {'string_found': False}
    return return_value

def return_data(data=None):  # noapi
    """
    Return all the data in our text database.
    """
    with Popen(['which', 'aws'], shell=True, stdout=PIPE) as p:
        output = p.stdout.read()
    try:
        assert not output.strip()
    except AssertionError:
        abort(503, custom='The server is incorrectly configured.')
    return_value = {'data': data}
    return return_value
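One environment-dependent cause worth ruling out (my own suggestion, not a confirmed diagnosis from this thread): by default, flask_restx's reqparse looks for arguments in the JSON body as well as in the query string, and newer Werkzeug/Flask releases turn an absent or unparseable body into exactly this kind of Bad Request 400. Pinning each GET argument to the query string sidesteps that code path; the sketch below reuses the argument names from the question:
# Hypothetical tweak: restrict parsing to the query string so reqparse never
# tries to interpret the request body on GET requests.
_findGetParser = reqparse.RequestParser()
_findGetParser.add_argument("txt", type=str, required=False, default="quick",
                            location="args",  # query string only
                            help="Your search term ")
_findGetParser.add_argument("window", type=int, required=False, default=5,
                            location="args",
                            help="Number of characters before and after first match to return ")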

Defining module variables from functions

I've finally been getting into Python, and have noticed something strange that works in Java but not in Python.
When I type the following:
fn = "" # Local filename storage.
def read(filename):
fn = filename
return open(filename, 'r').read()
My flake8 linter for Atom gives me the following error:
F841 - local variable 'fn' is assigned to but never used.
I'm assuming this means that the variable is being defined at the def level, and not at the module level, which is what I intended. Please correct me if I'm wrong.
I've searched Google, with multiple wordings, but can't seem to word it in a way that the correct results display...
Any ideas on how I can achieve module-level variable definitions from the function level?
If you want to declare fn as a global (module-level) variable, use the global statement.
def read(filename):
    global fn  # <-----
    fn = filename
    return open(filename, 'r').read()
BTW, a trailing ; at the end of a statement is optional in Python. Don't use it.
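To illustrate (example.txt is a hypothetical file): after read() runs, the module-level fn has been rebound:
content = read('example.txt')
print(fn)  # prints 'example.txt', because read() rebound the module-level name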
You can set a module-level variable from the function by doing:
import sys

def read(filename):
    module = sys.modules[__name__]
    setattr(module, 'fn', filename)
    return open(filename, 'r').read()
However, it's a very strange necessity. Consider changing your architecture.
UPD: Let's consider an example:
# module1
# uncomment it to fix NameError and AttributeError
# some_var = ''

def foo(val):
    global some_var
    some_var = val

# module2
from module1 import *

print(some_var)  # raises NameError: name 'some_var' is not defined
foo('bar')
print(some_var)  # still raises NameError: name 'some_var' is not defined

# module3
import module1

print(module1.some_var)  # raises AttributeError: 'module' object has no attribute 'some_var'
module1.foo('bar')
print(module1.some_var)  # prints 'bar' even without the some_var = '' definition in module1
So, it's not so obvious how global behaves during the import process. I think that manually doing setattr(module, 'attr_name', value) during the read() call is clearer.
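For instance (with a hypothetical module mymodule containing the setattr version of read(), and a hypothetical example.txt):
import mymodule

mymodule.read('example.txt')
print(mymodule.fn)  # 'example.txt': the attribute now exists on the module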

How to count sqlalchemy queries in unit tests

In Django, I often assert the number of queries that should be made so that unit tests catch new N+1 query problems:
from django import db
from django.conf import settings
settings.DEBUG = True

class SendData(TestCase):
    def test_send(self):
        db.connection.queries = []
        event = Events.objects.all()[1:]
        s = str(event)  # QuerySet is lazy, force retrieval
        self.assertEquals(len(db.connection.queries), 2)
In SQLAlchemy, tracing to STDOUT is enabled by setting the echo flag on the engine:
engine.echo = True
What is the best way to write tests that count the number of queries made by SQLAlchemy?
class SendData(TestCase):
    def test_send(self):
        event = session.query(Events).first()
        s = str(event)
        self.assertEquals(..., 2)
I've created a context manager class for this purpose:
import sqlalchemy.event

class DBStatementCounter(object):
    """
    Use as a context manager to count the number of execute()'s performed
    against the given sqlalchemy connection.

    Usage:
        with DBStatementCounter(conn) as ctr:
            conn.execute("SELECT 1")
            conn.execute("SELECT 1")
        assert ctr.get_count() == 2
    """
    def __init__(self, conn):
        self.conn = conn
        self.count = 0
        # Will have to rely on this flag since sqlalchemy 0.8 does not
        # support removing event listeners
        self.do_count = False
        sqlalchemy.event.listen(conn, 'after_execute', self.callback)

    def __enter__(self):
        self.do_count = True
        return self

    def __exit__(self, *_):
        self.do_count = False

    def get_count(self):
        return self.count

    def callback(self, *_):
        if self.do_count:
            self.count += 1
Use SQLAlchemy Core Events to log/track the queries executed (you can attach the listener from your unit tests so it doesn't impact performance of the actual application):
event.listen(engine, "before_cursor_execute", catch_queries)
Now you write the function catch_queries; how depends on how you test. For example, you could define this function inside your test:
def test_something(self):
    stmts = []
    def catch_queries(conn, cursor, statement, parameters, context, executemany):
        stmts.append(statement)
    # Now attach it as a listener and work with the collected events after running your test
The above method is just an inspiration. For extended cases you'd probably want a global cache of events that you empty after each test. The reason is that prior to 0.9 (the current dev version at the time) there is no API to remove event listeners. Thus, make one global listener that accesses a global list.
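A minimal sketch of that global-listener pattern (the in-memory SQLite engine and the test body are hypothetical stand-ins):
from sqlalchemy import create_engine, event, text

engine = create_engine("sqlite://")  # hypothetical test engine

captured_statements = []  # global cache shared by all tests

def catch_queries(conn, cursor, statement, parameters, context, executemany):
    captured_statements.append(statement)

# One global listener, attached once for the whole test run.
event.listen(engine, "before_cursor_execute", catch_queries)

def test_single_query():
    del captured_statements[:]  # empty the global cache before each test
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))
    assert len(captured_statements) == 1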
What about the approach of using flask_sqlalchemy.get_debug_queries()? BTW, this is the methodology used internally by the Flask Debug Toolbar; check its source.
from flask_sqlalchemy import get_debug_queries

def test_list_with_assuring_queries_count(app, client):
    with app.app_context():
        # here generating some test data
        for _ in range(10):
            notebook = create_test_scheduled_notebook_based_on_notebook_file(
                db.session, owner='testing_user',
                schedule={"kind": SCHEDULE_FREQUENCY_DAILY}
            )
        for _ in range(100):
            create_test_scheduled_notebook_run(db.session, notebook_id=notebook.id)
    with app.app_context():
        # after resetting the context, call the actual view we want to assert the number of queries for
        client.get(url_for('notebooks.personal_notebooks'))
        assert len(get_debug_queries()) == 3
Keep in mind that to reset the context and the query count, you have to enter with app.app_context() right before the exact code you want to measure.
Slightly modified version of @omar-tarabai's solution that removes the event listener when exiting the context:
from sqlalchemy import event

class QueryCounter(object):
    """Context manager to count SQLAlchemy queries."""

    def __init__(self, connection):
        self.connection = connection.engine
        self.count = 0

    def __enter__(self):
        event.listen(self.connection, "before_cursor_execute", self.callback)
        return self

    def __exit__(self, *args, **kwargs):
        event.remove(self.connection, "before_cursor_execute", self.callback)

    def callback(self, *args, **kwargs):
        self.count += 1
Usage:
with QueryCounter(session.connection()) as counter:
    session.query(XXX).all()
    session.query(YYY).all()
print(counter.count)  # 2

how does SQLAlchemy InvalidRequestError('Transaction XXX is not on the active transaction list') happen?

Recently I've got a SQLAlchemy InvalidRequestError.
The error log shows:
InvalidRequestError: Transaction <sqlalchemy.orm.session.SessionTransaction object at
0x106830dd0> is not on the active transaction list
In what circumstances will this error be raised?
-----Edit----
# the following two lines are actually in my decorator
s = Session()
s.add(model1)

# refer to http://techspot.zzzeek.org/2012/01/11/django-style-database-routers-in-sqlalchemy/
s2 = Session().using_bind('master')
model2 = s2.query(Model2).with_lockmode('update').get(1)
model2.somecolumn = 'new'
s2.commit()
This exception is raised
-----Edit2 -----
s = Session().using_bind('master')
model = Model(user_id=123456)
s.add(model)
s.flush()
# here, the exception is raised.
# I added a log in get_bind() of RoutingSession: when doing the flush, _name is None, so it returns engines['slave'].
# If I use commit() instead of flush(), it commits successfully.
I changed the using_bind method to the following and it works well. (My guess at why: the original using_bind copied the internals of a live session into a brand-new one via vars(s).update(vars(self)), so the two sessions ended up sharing transaction state, which is exactly the kind of internal manipulation described in the answer below.)
def using_bind(self, name):
    self._name = name
    return self
The previous RoutingSession:
class RoutingSession(Session):
    _name = None

    def get_bind(self, mapper=None, clause=None):
        logger.info(self._name)
        if self._name:
            return engines[self._name]
        elif self._flushing:
            logger.info('master')
            return engines['master']
        else:
            logger.info('slave')
            return engines['slave']

    def using_bind(self, name):
        s = RoutingSession()
        vars(s).update(vars(self))
        s._name = name
        return s
That's an internal assertion which should never occur. There's no way to answer this question without at least a full stack trace; perhaps you are improperly using the Session in a concurrent fashion, or manipulating its internals. I can only get that exception to raise by manipulating private methods or state pertaining to the Session object.
Here's that:
from sqlalchemy.orm import Session

s = Session()
s2 = Session()

t = s.transaction
t2 = s2.transaction

# nonsensical assignment of the SessionTransaction from one Session to also
# be referred to by another; this corrupts the transaction chain by leaving
# out "t2". ".transaction" should never be assigned to on the outside.
s2.transaction = t

t2.rollback()  # triggers the assertion case
Basically, the above should never happen, since you're not supposed to assign to .transaction; that's a read-only attribute.
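For contrast, a sketch of ordinary, supported usage (with a hypothetical in-memory engine), where each Session manages its own SessionTransaction and nothing internal is touched:
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")  # hypothetical engine

s = Session(bind=engine)
s2 = Session(bind=engine)

# roll back through the public API; never reach into .transaction
s.rollback()
s2.rollback()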