From 5499aa5e49cf701e3692f3851fcf1271126e03f7 Mon Sep 17 00:00:00 2001 From: John McCardle Date: Thu, 23 Dec 2021 13:57:50 -0500 Subject: [PATCH] Started classifying on-screen objects with unified "analyse_frame" function Rocks, lives, and missiles are identified by their rectangles from cv2.matchTemplate. Clusters from SIFT are checked against these detected rectangles. The remaining objects (like the rotated ship) are classified as "mysteries". These will be target for further analysis (like ship angle determination), or we can just shoot at them :) --- gamemodel.py | 95 +++++++++++++++++++++++++++++++++++++++++++--------- utility.py | 10 ++++++ 2 files changed, 90 insertions(+), 15 deletions(-) diff --git a/gamemodel.py b/gamemodel.py index fec0ba6..60c556b 100644 --- a/gamemodel.py +++ b/gamemodel.py @@ -38,6 +38,7 @@ class GameModel: return inner def clear_frame(self): + self.prev_frame = frame self.frame = None @with_frame @@ -60,7 +61,7 @@ class GameModel: color = { "big": (255, 0, 0), "normal": (0, 255, 0), "small": (0, 0, 255), - "missile": (128, 0, 0), + "missile": (0, 255, 128), "ship_on": (0, 0, 128), "ship_off": (0, 64, 128)}[label] cv2.rectangle(displayable, pt, wh, color, 1) @@ -97,7 +98,7 @@ class GameModel: ## #return { "matchsets": matchsets, ## # "kp_desc": kp_desc ## # } - ship_rsq = rect_radius_squared(*self.ships[0][1].shape) + ship_rsq = rect_radius_squared(*self.ships[0][1].shape) * 0.85 #print(f"max radius^2: {ship_rsq}") clusters = pointcluster.cluster_set([k.pt for k in frame_kp], sqrt(ship_rsq)) @@ -115,18 +116,76 @@ class GameModel: ship_rects.append((pt, (pt[0] + w, pt[1] + h), label)) return ship_rects -## @with_frame -## def find_missiles(self): -## """This technique does not work for the 9x9 pixel missile image.""" -## missile_rects = [] -## label, img = self.missile -## h, w = img.shape -## res = cv2.matchTemplate(self.frame, img, cv2.TM_CCOEFF_NORMED) -## loc = np.where( res >= self.cv_template_thresh) -## for pt in zip(*loc[::-1]): -## if not missile_rects or squared_distance(missile_rects[-1][0], pt) > self.duplicate_dist_thresh: -## missile_rects.append((pt, (pt[0] + w, pt[1] + h), label)) -## return missile_rects + @with_frame + def find_missiles(self): + # Setup SimpleBlobDetector parameters. + params = cv2.SimpleBlobDetector_Params() + + # Change thresholds + params.minThreshold = 10; + params.maxThreshold = 200; + + # Filter by Area. + params.filterByArea = True + #params.minArea = 1500 + params.maxArea = 100 + + # Filter by Circularity + #params.filterByCircularity = True + #params.minCircularity = 0.1 + + # Filter by Convexity + params.filterByConvexity = True + params.minConvexity = 0.95 + + # Filter by Inertia + params.filterByInertia = True + params.minInertiaRatio = 0.4 + + detector = cv2.SimpleBlobDetector_create(params) + keypoints = detector.detect(cv2.bitwise_not(self.frame)) # inverted black/white frame + #im_with_keypoints = cv2.drawKeypoints(self.frame, keypoints, np.array([]), + # (0,0,255), cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS) + #cv2.imshow("keypoints", im_with_keypoints) + #cv2.waitKey(0) + s = 9 # pixels for the missile + rect_tuple = lambda pt: ((int(pt[0]-s/2),int(pt[1]-s/2)), + (int(pt[0]+s/2), int(pt[1]+s/2)), + "missile") + return [rect_tuple(k.pt) for k in keypoints] + + def analyse_frame(self): + rocks = self.find_asteroids() + lives = self.find_ships() + shots = self.find_missiles() + clusters = self.frame_sift() + + labeled_objects = rocks + lives + shots + mystery_clusters = [] + # TODO: remove these comprehensions and document pretty utility functions. + easy_find = lambda cluster: any( + [cluster.max_distance < max(lo[1][0] - lo[0][0], lo[1][1] - lo[0][1]) + and point_in_rect(cluster.center, (lo[0], lo[1])) + for lo in labeled_objects]) + hard_find = lambda cluster: any( + [cluster.max_distance < max(lo[1][0] - lo[0][0], lo[1][1] - lo[0][1]) + and all([point_in_rect(p, (lo[0], lo[1])) + for p in cluster.points]) + for lo in labeled_objects]) + # Allow me to explain/apologize. + ## The first term (cluster.max_distance < ...) stops big point clusters from + ## being regarded as smalll objects. (Player ship being matched "inside" a missile) + + ## The second term (point_in_rect(...)) checks for a "cluster" inside a "rect". + ## easy_find just checks the center. + ## hard_find checks every point, in case the center is off. + + for i, c in enumerate(clusters): + #if easy_find(c): continue + if hard_find(c): continue + mystery_clusters.append(c) + r_circles = [(c.center, c.max_distance or 5, f"mystery_{i}") for i, c in enumerate(mystery_clusters)] + gm.display_results(rects=labeled_objects, circles=r_circles) if __name__ == '__main__': import platform @@ -147,6 +206,8 @@ if __name__ == '__main__': import pyscreeze io.loc = pyscreeze.Box(0, 25, 800, 599) + from pprint import pprint + #input("Press to detect asteroids on screen.") a_results = gm.find_asteroids() print(f"Found {len(a_results)} asteroids") @@ -158,4 +219,8 @@ if __name__ == '__main__': polygons = [c.points for c in s_results] #circles = [(c.center, c.max_distance, f"cluster_{i}") for i, c in enumerate(s_results)] r_circles = [(c.center, sqrt(rect_radius_squared(*gm.ships[0][1].shape)), f"cluster_{i}") for i, c in enumerate(s_results)] - gm.display_results(rects=a_results+ship_results, pointsets=polygons, circles=r_circles) + missile_results = gm.find_missiles() + #m_circles = [(pt, 10, f"missile_{i}") for i, pt in enumerate(missiles)] + #pprint(a_results+ship_results+missile_results) + gm.display_results(rects=a_results+ship_results+missile_results, pointsets=polygons, circles=r_circles) + gm.analyse_frame() diff --git a/utility.py b/utility.py index 599c38e..2bb9f27 100644 --- a/utility.py +++ b/utility.py @@ -7,3 +7,13 @@ def squared_distance(vec1, vec2): def rect_radius_squared(w, h): """Returns the radius^2 of the circle inscribed in a rectangle of w * h""" return (w/2)**2 + (h/2)**2 + +def point_in_rect(pt, rect): + """Returns True if the (x,y) point is within the ((x,y),(w,h)) rectangle.""" + px, py = pt + tl, wh = rect + rx, ry = tl + rw, rh = wh + rx2 = rx + rw + ry2 = ry + rh + return all([px >= rx, py >= ry, px <= rx2, py <= ry2])