Started classifying on-screen objects with unified "analyse_frame" function

Rocks, lives, and missiles are identified by their rectangles from cv2.matchTemplate.
Clusters from SIFT are checked against these detected rectangles.
The remaining objects (like the rotated ship) are classified as "mysteries". These will be target for further analysis (like ship angle determination), or we can just shoot at them :)
This commit is contained in:
John McCardle 2021-12-23 13:57:50 -05:00
parent 372f250167
commit 5499aa5e49
2 changed files with 90 additions and 15 deletions

View File

@ -38,6 +38,7 @@ class GameModel:
return inner
def clear_frame(self):
self.prev_frame = frame
self.frame = None
@with_frame
@ -60,7 +61,7 @@ class GameModel:
color = { "big": (255, 0, 0),
"normal": (0, 255, 0),
"small": (0, 0, 255),
"missile": (128, 0, 0),
"missile": (0, 255, 128),
"ship_on": (0, 0, 128),
"ship_off": (0, 64, 128)}[label]
cv2.rectangle(displayable, pt, wh, color, 1)
@ -97,7 +98,7 @@ class GameModel:
## #return { "matchsets": matchsets,
## # "kp_desc": kp_desc
## # }
ship_rsq = rect_radius_squared(*self.ships[0][1].shape)
ship_rsq = rect_radius_squared(*self.ships[0][1].shape) * 0.85
#print(f"max radius^2: {ship_rsq}")
clusters = pointcluster.cluster_set([k.pt for k in frame_kp], sqrt(ship_rsq))
@ -115,18 +116,76 @@ class GameModel:
ship_rects.append((pt, (pt[0] + w, pt[1] + h), label))
return ship_rects
## @with_frame
## def find_missiles(self):
## """This technique does not work for the 9x9 pixel missile image."""
## missile_rects = []
## label, img = self.missile
## h, w = img.shape
## res = cv2.matchTemplate(self.frame, img, cv2.TM_CCOEFF_NORMED)
## loc = np.where( res >= self.cv_template_thresh)
## for pt in zip(*loc[::-1]):
## if not missile_rects or squared_distance(missile_rects[-1][0], pt) > self.duplicate_dist_thresh:
## missile_rects.append((pt, (pt[0] + w, pt[1] + h), label))
## return missile_rects
@with_frame
def find_missiles(self):
# Setup SimpleBlobDetector parameters.
params = cv2.SimpleBlobDetector_Params()
# Change thresholds
params.minThreshold = 10;
params.maxThreshold = 200;
# Filter by Area.
params.filterByArea = True
#params.minArea = 1500
params.maxArea = 100
# Filter by Circularity
#params.filterByCircularity = True
#params.minCircularity = 0.1
# Filter by Convexity
params.filterByConvexity = True
params.minConvexity = 0.95
# Filter by Inertia
params.filterByInertia = True
params.minInertiaRatio = 0.4
detector = cv2.SimpleBlobDetector_create(params)
keypoints = detector.detect(cv2.bitwise_not(self.frame)) # inverted black/white frame
#im_with_keypoints = cv2.drawKeypoints(self.frame, keypoints, np.array([]),
# (0,0,255), cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
#cv2.imshow("keypoints", im_with_keypoints)
#cv2.waitKey(0)
s = 9 # pixels for the missile
rect_tuple = lambda pt: ((int(pt[0]-s/2),int(pt[1]-s/2)),
(int(pt[0]+s/2), int(pt[1]+s/2)),
"missile")
return [rect_tuple(k.pt) for k in keypoints]
def analyse_frame(self):
rocks = self.find_asteroids()
lives = self.find_ships()
shots = self.find_missiles()
clusters = self.frame_sift()
labeled_objects = rocks + lives + shots
mystery_clusters = []
# TODO: remove these comprehensions and document pretty utility functions.
easy_find = lambda cluster: any(
[cluster.max_distance < max(lo[1][0] - lo[0][0], lo[1][1] - lo[0][1])
and point_in_rect(cluster.center, (lo[0], lo[1]))
for lo in labeled_objects])
hard_find = lambda cluster: any(
[cluster.max_distance < max(lo[1][0] - lo[0][0], lo[1][1] - lo[0][1])
and all([point_in_rect(p, (lo[0], lo[1]))
for p in cluster.points])
for lo in labeled_objects])
# Allow me to explain/apologize.
## The first term (cluster.max_distance < ...) stops big point clusters from
## being regarded as smalll objects. (Player ship being matched "inside" a missile)
## The second term (point_in_rect(...)) checks for a "cluster" inside a "rect".
## easy_find just checks the center.
## hard_find checks every point, in case the center is off.
for i, c in enumerate(clusters):
#if easy_find(c): continue
if hard_find(c): continue
mystery_clusters.append(c)
r_circles = [(c.center, c.max_distance or 5, f"mystery_{i}") for i, c in enumerate(mystery_clusters)]
gm.display_results(rects=labeled_objects, circles=r_circles)
if __name__ == '__main__':
import platform
@ -147,6 +206,8 @@ if __name__ == '__main__':
import pyscreeze
io.loc = pyscreeze.Box(0, 25, 800, 599)
from pprint import pprint
#input("Press <enter> to detect asteroids on screen.")
a_results = gm.find_asteroids()
print(f"Found {len(a_results)} asteroids")
@ -158,4 +219,8 @@ if __name__ == '__main__':
polygons = [c.points for c in s_results]
#circles = [(c.center, c.max_distance, f"cluster_{i}") for i, c in enumerate(s_results)]
r_circles = [(c.center, sqrt(rect_radius_squared(*gm.ships[0][1].shape)), f"cluster_{i}") for i, c in enumerate(s_results)]
gm.display_results(rects=a_results+ship_results, pointsets=polygons, circles=r_circles)
missile_results = gm.find_missiles()
#m_circles = [(pt, 10, f"missile_{i}") for i, pt in enumerate(missiles)]
#pprint(a_results+ship_results+missile_results)
gm.display_results(rects=a_results+ship_results+missile_results, pointsets=polygons, circles=r_circles)
gm.analyse_frame()

View File

@ -7,3 +7,13 @@ def squared_distance(vec1, vec2):
def rect_radius_squared(w, h):
"""Returns the radius^2 of the circle inscribed in a rectangle of w * h"""
return (w/2)**2 + (h/2)**2
def point_in_rect(pt, rect):
"""Returns True if the (x,y) point is within the ((x,y),(w,h)) rectangle."""
px, py = pt
tl, wh = rect
rx, ry = tl
rw, rh = wh
rx2 = rx + rw
ry2 = ry + rh
return all([px >= rx, py >= ry, px <= rx2, py <= ry2])