Started classifying on-screen objects with unified "analyse_frame" function
Rocks, lives, and missiles are identified by their rectangles from cv2.matchTemplate. Clusters from SIFT are checked against these detected rectangles. The remaining objects (like the rotated ship) are classified as "mysteries". These will be target for further analysis (like ship angle determination), or we can just shoot at them :)
This commit is contained in:
parent
372f250167
commit
5499aa5e49
95
gamemodel.py
95
gamemodel.py
|
@ -38,6 +38,7 @@ class GameModel:
|
|||
return inner
|
||||
|
||||
def clear_frame(self):
|
||||
self.prev_frame = frame
|
||||
self.frame = None
|
||||
|
||||
@with_frame
|
||||
|
@ -60,7 +61,7 @@ class GameModel:
|
|||
color = { "big": (255, 0, 0),
|
||||
"normal": (0, 255, 0),
|
||||
"small": (0, 0, 255),
|
||||
"missile": (128, 0, 0),
|
||||
"missile": (0, 255, 128),
|
||||
"ship_on": (0, 0, 128),
|
||||
"ship_off": (0, 64, 128)}[label]
|
||||
cv2.rectangle(displayable, pt, wh, color, 1)
|
||||
|
@ -97,7 +98,7 @@ class GameModel:
|
|||
## #return { "matchsets": matchsets,
|
||||
## # "kp_desc": kp_desc
|
||||
## # }
|
||||
ship_rsq = rect_radius_squared(*self.ships[0][1].shape)
|
||||
ship_rsq = rect_radius_squared(*self.ships[0][1].shape) * 0.85
|
||||
#print(f"max radius^2: {ship_rsq}")
|
||||
clusters = pointcluster.cluster_set([k.pt for k in frame_kp], sqrt(ship_rsq))
|
||||
|
||||
|
@ -115,18 +116,76 @@ class GameModel:
|
|||
ship_rects.append((pt, (pt[0] + w, pt[1] + h), label))
|
||||
return ship_rects
|
||||
|
||||
## @with_frame
|
||||
## def find_missiles(self):
|
||||
## """This technique does not work for the 9x9 pixel missile image."""
|
||||
## missile_rects = []
|
||||
## label, img = self.missile
|
||||
## h, w = img.shape
|
||||
## res = cv2.matchTemplate(self.frame, img, cv2.TM_CCOEFF_NORMED)
|
||||
## loc = np.where( res >= self.cv_template_thresh)
|
||||
## for pt in zip(*loc[::-1]):
|
||||
## if not missile_rects or squared_distance(missile_rects[-1][0], pt) > self.duplicate_dist_thresh:
|
||||
## missile_rects.append((pt, (pt[0] + w, pt[1] + h), label))
|
||||
## return missile_rects
|
||||
@with_frame
|
||||
def find_missiles(self):
|
||||
# Setup SimpleBlobDetector parameters.
|
||||
params = cv2.SimpleBlobDetector_Params()
|
||||
|
||||
# Change thresholds
|
||||
params.minThreshold = 10;
|
||||
params.maxThreshold = 200;
|
||||
|
||||
# Filter by Area.
|
||||
params.filterByArea = True
|
||||
#params.minArea = 1500
|
||||
params.maxArea = 100
|
||||
|
||||
# Filter by Circularity
|
||||
#params.filterByCircularity = True
|
||||
#params.minCircularity = 0.1
|
||||
|
||||
# Filter by Convexity
|
||||
params.filterByConvexity = True
|
||||
params.minConvexity = 0.95
|
||||
|
||||
# Filter by Inertia
|
||||
params.filterByInertia = True
|
||||
params.minInertiaRatio = 0.4
|
||||
|
||||
detector = cv2.SimpleBlobDetector_create(params)
|
||||
keypoints = detector.detect(cv2.bitwise_not(self.frame)) # inverted black/white frame
|
||||
#im_with_keypoints = cv2.drawKeypoints(self.frame, keypoints, np.array([]),
|
||||
# (0,0,255), cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
|
||||
#cv2.imshow("keypoints", im_with_keypoints)
|
||||
#cv2.waitKey(0)
|
||||
s = 9 # pixels for the missile
|
||||
rect_tuple = lambda pt: ((int(pt[0]-s/2),int(pt[1]-s/2)),
|
||||
(int(pt[0]+s/2), int(pt[1]+s/2)),
|
||||
"missile")
|
||||
return [rect_tuple(k.pt) for k in keypoints]
|
||||
|
||||
def analyse_frame(self):
|
||||
rocks = self.find_asteroids()
|
||||
lives = self.find_ships()
|
||||
shots = self.find_missiles()
|
||||
clusters = self.frame_sift()
|
||||
|
||||
labeled_objects = rocks + lives + shots
|
||||
mystery_clusters = []
|
||||
# TODO: remove these comprehensions and document pretty utility functions.
|
||||
easy_find = lambda cluster: any(
|
||||
[cluster.max_distance < max(lo[1][0] - lo[0][0], lo[1][1] - lo[0][1])
|
||||
and point_in_rect(cluster.center, (lo[0], lo[1]))
|
||||
for lo in labeled_objects])
|
||||
hard_find = lambda cluster: any(
|
||||
[cluster.max_distance < max(lo[1][0] - lo[0][0], lo[1][1] - lo[0][1])
|
||||
and all([point_in_rect(p, (lo[0], lo[1]))
|
||||
for p in cluster.points])
|
||||
for lo in labeled_objects])
|
||||
# Allow me to explain/apologize.
|
||||
## The first term (cluster.max_distance < ...) stops big point clusters from
|
||||
## being regarded as smalll objects. (Player ship being matched "inside" a missile)
|
||||
|
||||
## The second term (point_in_rect(...)) checks for a "cluster" inside a "rect".
|
||||
## easy_find just checks the center.
|
||||
## hard_find checks every point, in case the center is off.
|
||||
|
||||
for i, c in enumerate(clusters):
|
||||
#if easy_find(c): continue
|
||||
if hard_find(c): continue
|
||||
mystery_clusters.append(c)
|
||||
r_circles = [(c.center, c.max_distance or 5, f"mystery_{i}") for i, c in enumerate(mystery_clusters)]
|
||||
gm.display_results(rects=labeled_objects, circles=r_circles)
|
||||
|
||||
if __name__ == '__main__':
|
||||
import platform
|
||||
|
@ -147,6 +206,8 @@ if __name__ == '__main__':
|
|||
import pyscreeze
|
||||
io.loc = pyscreeze.Box(0, 25, 800, 599)
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
#input("Press <enter> to detect asteroids on screen.")
|
||||
a_results = gm.find_asteroids()
|
||||
print(f"Found {len(a_results)} asteroids")
|
||||
|
@ -158,4 +219,8 @@ if __name__ == '__main__':
|
|||
polygons = [c.points for c in s_results]
|
||||
#circles = [(c.center, c.max_distance, f"cluster_{i}") for i, c in enumerate(s_results)]
|
||||
r_circles = [(c.center, sqrt(rect_radius_squared(*gm.ships[0][1].shape)), f"cluster_{i}") for i, c in enumerate(s_results)]
|
||||
gm.display_results(rects=a_results+ship_results, pointsets=polygons, circles=r_circles)
|
||||
missile_results = gm.find_missiles()
|
||||
#m_circles = [(pt, 10, f"missile_{i}") for i, pt in enumerate(missiles)]
|
||||
#pprint(a_results+ship_results+missile_results)
|
||||
gm.display_results(rects=a_results+ship_results+missile_results, pointsets=polygons, circles=r_circles)
|
||||
gm.analyse_frame()
|
||||
|
|
10
utility.py
10
utility.py
|
@ -7,3 +7,13 @@ def squared_distance(vec1, vec2):
|
|||
def rect_radius_squared(w, h):
|
||||
"""Returns the radius^2 of the circle inscribed in a rectangle of w * h"""
|
||||
return (w/2)**2 + (h/2)**2
|
||||
|
||||
def point_in_rect(pt, rect):
|
||||
"""Returns True if the (x,y) point is within the ((x,y),(w,h)) rectangle."""
|
||||
px, py = pt
|
||||
tl, wh = rect
|
||||
rx, ry = tl
|
||||
rw, rh = wh
|
||||
rx2 = rx + rw
|
||||
ry2 = ry + rh
|
||||
return all([px >= rx, py >= ry, px <= rx2, py <= ry2])
|
||||
|
|
Loading…
Reference in New Issue